VaultData should come directly from the database (#425)

This commit is contained in:
David Perez 2023-12-21 08:58:22 -06:00 committed by Álison Fernandes
parent b2692a5637
commit f2842446c9
4 changed files with 503 additions and 785 deletions

View file

@ -38,3 +38,102 @@ fun <T : Any?> MutableStateFlow<DataState<T>>.updateToPendingOrLoading() {
?: DataState.Loading
}
}
/**
* Combines the [dataState1] and [dataState2] [DataState]s together using the provided [transform].
*
* This function will internally manage the final `DataState` type that is returned. This is done
* by prioritizing each if the states in this order:
*
* - [DataState.Error]
* - [DataState.NoNetwork]
* - [DataState.Loading]
* - [DataState.Pending]
* - [DataState.Loaded]
*
* This priority order ensures that the total state is accurately reflecting the underlying states.
* If one of the `DataState`s has a higher priority than the other, the output will be the highest
* priority. For example, if one `DataState` is `DataState.Loaded` and another is `DataState.Error`
* then the returned `DataState` will be `DataState.Error`.
*
* The payload of the final `DataState` be created by the `transform` lambda which will be invoked
* whenever the payloads of both `dataState1` and `dataState2` are not null. In the scenario where
* one or both payloads are null, the resulting `DataState` will have a null payload.
*/
fun <T1, T2, R> combineDataStates(
dataState1: DataState<T1>,
dataState2: DataState<T2>,
transform: (t1: T1, t2: T2) -> R,
): DataState<R> {
// Wraps the `transform` lambda to allow null data to be passed in. If either of the passed in
// values are null, the regular transform will not be invoked and null is returned.
val nullableTransform: (T1?, T2?) -> R? = { t1, t2 ->
if (t1 != null && t2 != null) transform(t1, t2) else null
}
return when {
// Error states have highest priority, fail fast.
dataState1 is DataState.Error -> {
DataState.Error(
error = dataState1.error,
data = nullableTransform(dataState1.data, dataState2.data),
)
}
dataState2 is DataState.Error -> {
DataState.Error(
error = dataState2.error,
data = nullableTransform(dataState1.data, dataState2.data),
)
}
dataState1 is DataState.NoNetwork || dataState2 is DataState.NoNetwork -> {
DataState.NoNetwork(nullableTransform(dataState1.data, dataState2.data))
}
// Something is still loading, we will wait for all the data.
dataState1 is DataState.Loading || dataState2 is DataState.Loading -> DataState.Loading
// Pending state for everything while any one piece of data is updating.
dataState1 is DataState.Pending || dataState2 is DataState.Pending -> {
DataState.Pending(
transform(requireNotNull(dataState1.data), requireNotNull(dataState2.data)),
)
}
// Both states are Loaded and have data
else -> {
DataState.Loaded(
transform(requireNotNull(dataState1.data), requireNotNull(dataState2.data)),
)
}
}
}
/**
* Combines the [dataState1], [dataState2], and [dataState3] [DataState]s together using the
* provided [transform].
*
* See [combineDataStates] for details.
*/
fun <T1, T2, T3, R> combineDataStates(
dataState1: DataState<T1>,
dataState2: DataState<T2>,
dataState3: DataState<T3>,
transform: (t1: T1, t2: T2, t3: T3) -> R,
): DataState<R> =
dataState1
.combineDataStatesWith(dataState2) { t1, t2 -> t1 to t2 }
.combineDataStatesWith(dataState3) { t1t2Pair, t3 ->
transform(t1t2Pair.first, t1t2Pair.second, t3)
}
/**
* Combines [dataState2] with the given [DataState] using the provided [transform].
*
* See [combineDataStates] for details.
*/
fun <T1, T2, R> DataState<T1>.combineDataStatesWith(
dataState2: DataState<T2>,
transform: (t1: T1, t2: T2) -> R,
): DataState<R> =
combineDataStates(this, dataState2, transform)

View file

@ -13,12 +13,12 @@ import com.x8bit.bitwarden.data.auth.repository.util.toUpdatedUserStateJson
import com.x8bit.bitwarden.data.platform.datasource.network.util.isNoConnectionError
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
import com.x8bit.bitwarden.data.platform.repository.model.DataState
import com.x8bit.bitwarden.data.platform.repository.util.combineDataStates
import com.x8bit.bitwarden.data.platform.repository.util.map
import com.x8bit.bitwarden.data.platform.repository.util.observeWhenSubscribedAndLoggedIn
import com.x8bit.bitwarden.data.platform.repository.util.updateToPendingOrLoading
import com.x8bit.bitwarden.data.platform.util.asSuccess
import com.x8bit.bitwarden.data.platform.util.flatMap
import com.x8bit.bitwarden.data.platform.util.zip
import com.x8bit.bitwarden.data.vault.datasource.disk.VaultDiskSource
import com.x8bit.bitwarden.data.vault.datasource.network.model.SyncResponseJson
import com.x8bit.bitwarden.data.vault.datasource.network.service.CiphersService
@ -39,12 +39,12 @@ import com.x8bit.bitwarden.data.vault.repository.util.toEncryptedSdkSendList
import com.x8bit.bitwarden.data.vault.repository.util.toVaultUnlockResult
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.async
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.combine
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.launchIn
@ -55,7 +55,12 @@ import kotlinx.coroutines.flow.onStart
import kotlinx.coroutines.flow.stateIn
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
/**
* A "stop timeout delay" in milliseconds used to let a shared coroutine continue to run for the
* specified period of time after it no longer has subscribers.
*/
private const val STOP_TIMEOUT_DELAY_MS: Long = 1000L
/**
* Default implementation of [VaultRepository].
@ -67,10 +72,11 @@ class VaultRepositoryImpl(
private val vaultDiskSource: VaultDiskSource,
private val vaultSdkSource: VaultSdkSource,
private val authDiskSource: AuthDiskSource,
private val dispatcherManager: DispatcherManager,
dispatcherManager: DispatcherManager,
) : VaultRepository {
private val scope = CoroutineScope(dispatcherManager.io)
private val unconfinedScope = CoroutineScope(dispatcherManager.unconfined)
private val ioScope = CoroutineScope(dispatcherManager.io)
private var syncJob: Job = Job().apply { complete() }
@ -78,9 +84,6 @@ class VaultRepositoryImpl(
private val activeUserId: String? get() = authDiskSource.userState?.activeUserId
private val mutableVaultDataStateFlow =
MutableStateFlow<DataState<VaultData>>(DataState.Loading)
private val mutableVaultStateStateFlow =
MutableStateFlow(VaultState(unlockedVaultUserIds = emptySet()))
@ -95,8 +98,29 @@ class VaultRepositoryImpl(
private val mutableCollectionsStateFlow =
MutableStateFlow<DataState<List<CollectionView>>>(DataState.Loading)
override val vaultDataStateFlow: StateFlow<DataState<VaultData>>
get() = mutableVaultDataStateFlow.asStateFlow()
override val vaultDataStateFlow: StateFlow<DataState<VaultData>> =
combine(
ciphersStateFlow,
foldersStateFlow,
collectionsStateFlow,
) { ciphersDataState, foldersDataState, collectionsDataState ->
combineDataStates(
ciphersDataState,
foldersDataState,
collectionsDataState,
) { ciphersData, foldersData, collectionsData ->
VaultData(
cipherViewList = ciphersData,
folderViewList = foldersData,
collectionViewList = collectionsData,
)
}
}
.stateIn(
scope = unconfinedScope,
started = SharingStarted.WhileSubscribed(stopTimeoutMillis = STOP_TIMEOUT_DELAY_MS),
initialValue = DataState.Loading,
)
override val ciphersStateFlow: StateFlow<DataState<List<CipherView>>>
get() = mutableCiphersStateFlow.asStateFlow()
@ -119,31 +143,30 @@ class VaultRepositoryImpl(
.observeWhenSubscribedAndLoggedIn(authDiskSource.userStateFlow) { activeUserId ->
observeVaultDiskCiphers(activeUserId)
}
.launchIn(scope)
.launchIn(unconfinedScope)
// Setup folders MutableStateFlow
mutableFoldersStateFlow
.observeWhenSubscribedAndLoggedIn(authDiskSource.userStateFlow) { activeUserId ->
observeVaultDiskFolders(activeUserId)
}
.launchIn(scope)
.launchIn(unconfinedScope)
// Setup collections MutableStateFlow
mutableCollectionsStateFlow
.observeWhenSubscribedAndLoggedIn(authDiskSource.userStateFlow) { activeUserId ->
observeVaultDiskCollections(activeUserId)
}
.launchIn(scope)
.launchIn(unconfinedScope)
}
override fun clearUnlockedData() {
mutableCiphersStateFlow.update { DataState.Loading }
mutableFoldersStateFlow.update { DataState.Loading }
mutableCollectionsStateFlow.update { DataState.Loading }
mutableVaultDataStateFlow.update { DataState.Loading }
mutableSendDataStateFlow.update { DataState.Loading }
}
override fun deleteVaultData(userId: String) {
scope.launch {
ioScope.launch {
vaultDiskSource.deleteVaultData(userId)
}
}
@ -154,9 +177,8 @@ class VaultRepositoryImpl(
mutableCiphersStateFlow.updateToPendingOrLoading()
mutableFoldersStateFlow.updateToPendingOrLoading()
mutableCollectionsStateFlow.updateToPendingOrLoading()
mutableVaultDataStateFlow.updateToPendingOrLoading()
mutableSendDataStateFlow.updateToPendingOrLoading()
syncJob = scope.launch {
syncJob = ioScope.launch {
syncService
.sync()
.fold(
@ -170,10 +192,7 @@ class VaultRepositoryImpl(
unlockVaultForOrganizationsIfNecessary(syncResponse = syncResponse)
storeKeys(syncResponse = syncResponse)
decryptSyncResponseAndUpdateVaultDataState(
userId = userId,
syncResponse = syncResponse,
)
vaultDiskSource.replaceVaultData(userId = userId, vault = syncResponse)
decryptSendsAndUpdateSendDataState(sendList = syncResponse.sends)
},
onFailure = { throwable ->
@ -192,11 +211,6 @@ class VaultRepositoryImpl(
data = currentState.data,
)
}
mutableVaultDataStateFlow.update { currentState ->
throwable.toNetworkOrErrorState(
data = currentState.data,
)
}
mutableSendDataStateFlow.update { currentState ->
throwable.toNetworkOrErrorState(
data = currentState.data,
@ -217,7 +231,7 @@ class VaultRepositoryImpl(
}
}
.stateIn(
scope = scope,
scope = unconfinedScope,
started = SharingStarted.Lazily,
initialValue = DataState.Loading,
)
@ -232,7 +246,7 @@ class VaultRepositoryImpl(
}
}
.stateIn(
scope = scope,
scope = unconfinedScope,
started = SharingStarted.Lazily,
initialValue = DataState.Loading,
)
@ -450,58 +464,6 @@ class VaultRepositoryImpl(
mutableSendDataStateFlow.update { newState }
}
private suspend fun decryptSyncResponseAndUpdateVaultDataState(
userId: String,
syncResponse: SyncResponseJson,
) = withContext(dispatcherManager.default) {
val deferred = async {
vaultDiskSource.replaceVaultData(userId = userId, vault = syncResponse)
}
// Allow decryption of various types in parallel.
val newState = zip(
{
vaultSdkSource
.decryptCipherList(
cipherList = syncResponse
.ciphers
.orEmpty()
.toEncryptedSdkCipherList(),
)
},
{
vaultSdkSource
.decryptFolderList(
folderList = syncResponse
.folders
.orEmpty()
.toEncryptedSdkFolderList(),
)
},
{
vaultSdkSource
.decryptCollectionList(
collectionList = syncResponse
.collections
.orEmpty()
.toEncryptedSdkCollectionList(),
)
},
) { decryptedCipherList, decryptedFolderList, decryptedCollectionList ->
VaultData(
cipherViewList = decryptedCipherList,
collectionViewList = decryptedCollectionList,
folderViewList = decryptedFolderList,
)
}
.fold(
onSuccess = { DataState.Loaded(data = it) },
onFailure = { DataState.Error(error = it) },
)
mutableVaultDataStateFlow.update { newState }
deferred.await()
}
private fun observeVaultDiskCiphers(
userId: String,
): Flow<DataState<List<CipherView>>> =

View file

@ -41,4 +41,211 @@ class DataStateExtensionsTest {
assertEquals(DataState.Loading, mutableStateFlow.value)
}
@Suppress("MaxLineLength")
@Test
fun `combineDataStates should return an empty Error when the first dataState is Error without data`() {
val throwable = Throwable("Fail")
val dataState1 = DataState.Error<String>(throwable)
val dataState2 = DataState.Loaded(5)
val result = combineDataStates(
dataState1 = dataState1,
dataState2 = dataState2,
) { data1, data2 ->
data1 to data2
}
assertEquals(DataState.Error<Pair<String, Int>>(throwable), result)
}
@Suppress("MaxLineLength")
@Test
fun `combineDataStates should return a populated Error when the first dataState is Error with data`() {
val throwable = Throwable("Fail")
val dataState1 = DataState.Error(throwable, "data")
val dataState2 = DataState.Loaded(5)
val result = combineDataStates(
dataState1 = dataState1,
dataState2 = dataState2,
) { data1, data2 ->
data1 to data2
}
assertEquals(DataState.Error(throwable, "data" to 5), result)
}
@Suppress("MaxLineLength")
@Test
fun `combineDataStates should return an empty Error when the second dataState is Error without data`() {
val throwable = Throwable("Fail")
val dataState1 = DataState.Loaded(5)
val dataState2 = DataState.Error<String>(throwable)
val result = combineDataStates(
dataState1 = dataState1,
dataState2 = dataState2,
) { data1, data2 ->
data1 to data2
}
assertEquals(DataState.Error<Pair<Int, String>>(throwable), result)
}
@Suppress("MaxLineLength")
@Test
fun `combineDataStates should return a populated Error when the second dataState is Error with data`() {
val throwable = Throwable("Fail")
val dataState1 = DataState.Loaded(5)
val dataState2 = DataState.Error(throwable, "data")
val result = combineDataStates(
dataState1 = dataState1,
dataState2 = dataState2,
) { data1, data2 ->
data1 to data2
}
assertEquals(DataState.Error(throwable, 5 to "data"), result)
}
@Suppress("MaxLineLength")
@Test
fun `combineDataStates should return an empty NoNetwork when the first dataState is NoNetwork without data`() {
val dataState1 = DataState.NoNetwork<Int>()
val dataState2 = DataState.Loaded("data")
val result = combineDataStates(
dataState1 = dataState1,
dataState2 = dataState2,
) { data1, data2 ->
data1 to data2
}
assertEquals(DataState.NoNetwork<Pair<Int, String>>(), result)
}
@Suppress("MaxLineLength")
@Test
fun `combineDataStates should return a populated NoNetwork when the first dataState is NoNetwork with data`() {
val dataState1 = DataState.NoNetwork(5)
val dataState2 = DataState.Loaded("data")
val result = combineDataStates(
dataState1 = dataState1,
dataState2 = dataState2,
) { data1, data2 ->
data1 to data2
}
assertEquals(DataState.NoNetwork(5 to "data"), result)
}
@Suppress("MaxLineLength")
@Test
fun `combineDataStates should return an empty NoNetwork when the second dataState is NoNetwork without data`() {
val dataState1 = DataState.Loaded("data")
val dataState2 = DataState.NoNetwork<Int>()
val result = combineDataStates(
dataState1 = dataState1,
dataState2 = dataState2,
) { data1, data2 ->
data1 to data2
}
assertEquals(DataState.NoNetwork<Pair<String, Int>>(), result)
}
@Suppress("MaxLineLength")
@Test
fun `combineDataStates should return a populated NoNetwork when the second dataState is NoNetwork with data`() {
val dataState1 = DataState.Loaded("data")
val dataState2 = DataState.NoNetwork(5)
val result = combineDataStates(
dataState1 = dataState1,
dataState2 = dataState2,
) { data1, data2 ->
data1 to data2
}
assertEquals(DataState.NoNetwork("data" to 5), result)
}
@Test
fun `combineDataStates should return Loading when the first dataState is Loading`() {
val dataState1: DataState<Int> = DataState.Loading
val dataState2 = DataState.Loaded("data")
val result = combineDataStates(
dataState1 = dataState1,
dataState2 = dataState2,
) { data1, data2 ->
data1 to data2
}
assertEquals(DataState.Loading, result)
}
@Test
fun `combineDataStates should return Loading when the second dataState is Loading`() {
val dataState1 = DataState.Loaded("data")
val dataState2: DataState<Int> = DataState.Loading
val result = combineDataStates(
dataState1 = dataState1,
dataState2 = dataState2,
) { data1, data2 ->
data1 to data2
}
assertEquals(DataState.Loading, result)
}
@Test
fun `combineDataStates should return Pending when the first dataState is Pending`() {
val dataState1 = DataState.Pending(5)
val dataState2 = DataState.Loaded("data")
val result = combineDataStates(
dataState1 = dataState1,
dataState2 = dataState2,
) { data1, data2 ->
data1 to data2
}
assertEquals(DataState.Pending(5 to "data"), result)
}
@Test
fun `combineDataStates should return Pending when the second dataState is Pending`() {
val dataState1 = DataState.Loaded("data")
val dataState2 = DataState.Pending(5)
val result = combineDataStates(
dataState1 = dataState1,
dataState2 = dataState2,
) { data1, data2 ->
data1 to data2
}
assertEquals(DataState.Pending("data" to 5), result)
}
@Test
fun `combineDataStates should return Loaded when the both dataStates are Loaded`() {
val dataState1 = DataState.Loaded("data")
val dataState2 = DataState.Loaded(5)
val result = combineDataStates(
dataState1 = dataState1,
dataState2 = dataState2,
) { data1, data2 ->
data1 to data2
}
assertEquals(DataState.Loaded("data" to 5), result)
}
}