Add the SdkClientManager and use a single Client per user for vault (#499)

This commit is contained in:
Brian Yencho 2024-01-05 08:57:49 -06:00 committed by Álison Fernandes
parent 6486a6dc6a
commit 41c35e23dd
12 changed files with 671 additions and 107 deletions

View file

@ -0,0 +1,21 @@
package com.x8bit.bitwarden.data.platform.manager
import com.bitwarden.sdk.Client
/**
* Manages the creation, caching, and destruction of SDK [Client] instances on a per-user basis.
*/
interface SdkClientManager {
/**
* Returns the cached [Client] instance for the given [userId], otherwise creates and caches
* a new one and returns it.
*/
fun getOrCreateClient(userId: String): Client
/**
* Clears any resources from the [Client] associated with the given [userId] and removes it
* from the internal cache.
*/
fun destroyClient(userId: String)
}

View file

@ -0,0 +1,25 @@
package com.x8bit.bitwarden.data.platform.manager
import com.bitwarden.sdk.Client
/**
* Primary implementation of [SdkClientManager].
*/
class SdkClientManagerImpl(
private val clientProvider: () -> Client = { Client(null) },
) : SdkClientManager {
private val userIdToClientMap = mutableMapOf<String, Client>()
override fun getOrCreateClient(
userId: String,
): Client =
userIdToClientMap.getOrPut(key = userId) { clientProvider() }
override fun destroyClient(
userId: String,
) {
userIdToClientMap
.remove(key = userId)
?.close()
}
}

View file

@ -6,6 +6,8 @@ import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthToke
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors
import com.x8bit.bitwarden.data.platform.manager.NetworkConfigManager import com.x8bit.bitwarden.data.platform.manager.NetworkConfigManager
import com.x8bit.bitwarden.data.platform.manager.NetworkConfigManagerImpl import com.x8bit.bitwarden.data.platform.manager.NetworkConfigManagerImpl
import com.x8bit.bitwarden.data.platform.manager.SdkClientManager
import com.x8bit.bitwarden.data.platform.manager.SdkClientManagerImpl
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManagerImpl import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManagerImpl
import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository
@ -26,6 +28,10 @@ object PlatformManagerModule {
@Singleton @Singleton
fun provideBitwardenDispatchers(): DispatcherManager = DispatcherManagerImpl() fun provideBitwardenDispatchers(): DispatcherManager = DispatcherManagerImpl()
@Provides
@Singleton
fun provideSdkClientManager(): SdkClientManager = SdkClientManagerImpl()
@Provides @Provides
@Singleton @Singleton
fun provideNetworkConfigManager( fun provideNetworkConfigManager(

View file

@ -67,7 +67,11 @@ class GeneratorRepositoryImpl(
.onStart { mutablePasswordHistoryStateFlow.value = LocalDataState.Loading } .onStart { mutablePasswordHistoryStateFlow.value = LocalDataState.Loading }
.map { encryptedPasswordHistoryList -> .map { encryptedPasswordHistoryList ->
val passwordHistories = encryptedPasswordHistoryList.map { it.toPasswordHistory() } val passwordHistories = encryptedPasswordHistoryList.map { it.toPasswordHistory() }
vaultSdkSource.decryptPasswordHistoryList(passwordHistories) vaultSdkSource
.decryptPasswordHistoryList(
userId = userId,
passwordHistoryList = passwordHistories,
)
} }
.onEach { encryptedPasswordHistoryListResult -> .onEach { encryptedPasswordHistoryListResult ->
mutablePasswordHistoryStateFlow.value = encryptedPasswordHistoryListResult.fold( mutablePasswordHistoryStateFlow.value = encryptedPasswordHistoryListResult.fold(
@ -148,7 +152,10 @@ class GeneratorRepositoryImpl(
override suspend fun storePasswordHistory(passwordHistoryView: PasswordHistoryView) { override suspend fun storePasswordHistory(passwordHistoryView: PasswordHistoryView) {
val userId = authDiskSource.userState?.activeUserId ?: return val userId = authDiskSource.userState?.activeUserId ?: return
val encryptedPasswordHistory = vaultSdkSource val encryptedPasswordHistory = vaultSdkSource
.encryptPasswordHistory(passwordHistoryView) .encryptPasswordHistory(
userId = userId,
passwordHistory = passwordHistoryView,
)
.getOrNull() ?: return .getOrNull() ?: return
passwordHistoryDiskSource.insertPasswordHistory( passwordHistoryDiskSource.insertPasswordHistory(
encryptedPasswordHistory.toPasswordHistoryEntity(userId), encryptedPasswordHistory.toPasswordHistoryEntity(userId),

View file

@ -22,84 +22,171 @@ import com.x8bit.bitwarden.data.vault.datasource.sdk.model.InitializeCryptoResul
interface VaultSdkSource { interface VaultSdkSource {
/** /**
* Attempts to initialize cryptography functionality for an individual user for the * Clears any cryptography-related functionality for the given [userId], effectively locking
* Bitwarden SDK with a given [InitUserCryptoRequest]. * the associated vault.
*/ */
suspend fun initializeCrypto(request: InitUserCryptoRequest): Result<InitializeCryptoResult> fun clearCrypto(userId: String)
/**
* Attempts to initialize cryptography functionality for an individual user with the given
* [userId] for the Bitwarden SDK with a given [InitUserCryptoRequest].
*/
suspend fun initializeCrypto(
userId: String,
request: InitUserCryptoRequest,
): Result<InitializeCryptoResult>
/** /**
* Attempts to initialize cryptography functionality for organization data associated with * Attempts to initialize cryptography functionality for organization data associated with
* the current user for the Bitwarden SDK with a given [InitOrgCryptoRequest]. * the user with the given [userId] for the Bitwarden SDK with a given [InitOrgCryptoRequest].
* *
* This should only be called after a successful call to [initializeCrypto]. * This should only be called after a successful call to [initializeCrypto] for the associated
* user.
*/ */
suspend fun initializeOrganizationCrypto( suspend fun initializeOrganizationCrypto(
userId: String,
request: InitOrgCryptoRequest, request: InitOrgCryptoRequest,
): Result<InitializeCryptoResult> ): Result<InitializeCryptoResult>
/** /**
* Encrypts a [CipherView] returning a [Cipher] wrapped in a [Result]. * Encrypts a [CipherView] for the user with the given [userId], returning a [Cipher] wrapped
* in a [Result].
*
* This should only be called after a successful call to [initializeCrypto] for the associated
* user.
*/ */
suspend fun encryptCipher(cipherView: CipherView): Result<Cipher> suspend fun encryptCipher(
userId: String,
cipherView: CipherView,
): Result<Cipher>
/** /**
* Decrypts a [Cipher] returning a [CipherView] wrapped in a [Result]. * Decrypts a [Cipher] for the user with the given [userId], returning a [CipherView] wrapped
* in a [Result].
*
* This should only be called after a successful call to [initializeCrypto] for the associated
* user.
*/ */
suspend fun decryptCipher(cipher: Cipher): Result<CipherView> suspend fun decryptCipher(
userId: String,
cipher: Cipher,
): Result<CipherView>
/** /**
* Decrypts a list of [Cipher]s returning a list of [CipherListView] wrapped in a [Result]. * Decrypts a list of [Cipher]s for the user with the given [userId], returning a list of
* [CipherListView] wrapped in a [Result].
*
* This should only be called after a successful call to [initializeCrypto] for the associated
* user.
*/ */
suspend fun decryptCipherListCollection(cipherList: List<Cipher>): Result<List<CipherListView>> suspend fun decryptCipherListCollection(
userId: String,
cipherList: List<Cipher>,
): Result<List<CipherListView>>
/** /**
* Decrypts a list of [Cipher]s returning a list of [CipherView] wrapped in a [Result]. * Decrypts a list of [Cipher]s for the user with the given [userId], returning a list of
* [CipherView] wrapped in a [Result].
*
* This should only be called after a successful call to [initializeCrypto] for the associated
* user.
*/ */
suspend fun decryptCipherList(cipherList: List<Cipher>): Result<List<CipherView>> suspend fun decryptCipherList(
userId: String,
cipherList: List<Cipher>,
): Result<List<CipherView>>
/** /**
* Decrypts a [Collection] returning a [CollectionView] wrapped in a [Result]. * Decrypts a [Collection] for the user with the given [userId], returning a [CollectionView]
* wrapped in a [Result].
*
* This should only be called after a successful call to [initializeCrypto] for the associated
* user.
*/ */
suspend fun decryptCollection(collection: Collection): Result<CollectionView> suspend fun decryptCollection(
userId: String,
collection: Collection,
): Result<CollectionView>
/** /**
* Decrypts a list of [Collection]s returning a list of [CollectionView] wrapped in a [Result]. * Decrypts a list of [Collection]s for the user with the given [userId], returning a list of
* [CollectionView] wrapped in a [Result].
*
* This should only be called after a successful call to [initializeCrypto] for the associated
* user.
*/ */
suspend fun decryptCollectionList( suspend fun decryptCollectionList(
userId: String,
collectionList: List<Collection>, collectionList: List<Collection>,
): Result<List<CollectionView>> ): Result<List<CollectionView>>
/** /**
* Decrypts a [Send] returning a [SendView] wrapped in a [Result]. * Decrypts a [Send] for the user with the given [userId], returning a [SendView] wrapped in a
* [Result].
*
* This should only be called after a successful call to [initializeCrypto] for the associated
* user.
*/ */
suspend fun decryptSend(send: Send): Result<SendView> suspend fun decryptSend(
userId: String,
send: Send,
): Result<SendView>
/** /**
* Decrypts a list of [Send]s returning a list of [SendView] wrapped in a [Result]. * Decrypts a list of [Send]s for the user with the given [userId], returning a list of
* [SendView] wrapped in a [Result].
*
* This should only be called after a successful call to [initializeCrypto] for the associated
* user.
*/ */
suspend fun decryptSendList(sendList: List<Send>): Result<List<SendView>> suspend fun decryptSendList(
userId: String,
sendList: List<Send>,
): Result<List<SendView>>
/** /**
* Decrypts a [Folder] returning a [FolderView] wrapped in a [Result]. * Decrypts a [Folder] for the user with the given [userId], returning a [FolderView] wrapped
* in a [Result].
*
* This should only be called after a successful call to [initializeCrypto] for the associated
* user.
*/ */
suspend fun decryptFolder(folder: Folder): Result<FolderView> suspend fun decryptFolder(
userId: String,
folder: Folder,
): Result<FolderView>
/** /**
* Decrypts a list of [Folder]s returning a list of [FolderView] wrapped in a [Result]. * Decrypts a list of [Folder]s for the user with the given [userId], returning a list of
* [FolderView] wrapped in a [Result].
*
* This should only be called after a successful call to [initializeCrypto] for the associated
* user.
*/ */
suspend fun decryptFolderList(folderList: List<Folder>): Result<List<FolderView>> suspend fun decryptFolderList(
userId: String,
folderList: List<Folder>,
): Result<List<FolderView>>
/** /**
* Encrypts a given password history item. * Encrypts a given password history item for the user with the given [userId].
*
* This should only be called after a successful call to [initializeCrypto] for the associated
* user.
*/ */
suspend fun encryptPasswordHistory( suspend fun encryptPasswordHistory(
userId: String,
passwordHistory: PasswordHistoryView, passwordHistory: PasswordHistoryView,
): Result<PasswordHistory> ): Result<PasswordHistory>
/** /**
* Decrypts a list of password history items. * Decrypts a list of password history items for the user with the given [userId].
*
* This should only be called after a successful call to [initializeCrypto] for the associated
* user.
*/ */
suspend fun decryptPasswordHistoryList( suspend fun decryptPasswordHistoryList(
userId: String,
passwordHistoryList: List<PasswordHistory>, passwordHistoryList: List<PasswordHistory>,
): Result<List<PasswordHistoryView>> ): Result<List<PasswordHistoryView>>
} }

View file

@ -14,9 +14,9 @@ import com.bitwarden.core.PasswordHistoryView
import com.bitwarden.core.Send import com.bitwarden.core.Send
import com.bitwarden.core.SendView import com.bitwarden.core.SendView
import com.bitwarden.sdk.BitwardenException import com.bitwarden.sdk.BitwardenException
import com.bitwarden.sdk.ClientCrypto import com.bitwarden.sdk.Client
import com.bitwarden.sdk.ClientPasswordHistory
import com.bitwarden.sdk.ClientVault import com.bitwarden.sdk.ClientVault
import com.x8bit.bitwarden.data.platform.manager.SdkClientManager
import com.x8bit.bitwarden.data.vault.datasource.sdk.model.InitializeCryptoResult import com.x8bit.bitwarden.data.vault.datasource.sdk.model.InitializeCryptoResult
/** /**
@ -25,16 +25,21 @@ import com.x8bit.bitwarden.data.vault.datasource.sdk.model.InitializeCryptoResul
*/ */
@Suppress("TooManyFunctions") @Suppress("TooManyFunctions")
class VaultSdkSourceImpl( class VaultSdkSourceImpl(
private val clientVault: ClientVault, private val sdkClientManager: SdkClientManager,
private val clientCrypto: ClientCrypto,
private val clientPasswordHistory: ClientPasswordHistory,
) : VaultSdkSource { ) : VaultSdkSource {
override fun clearCrypto(userId: String) {
sdkClientManager.destroyClient(userId = userId)
}
override suspend fun initializeCrypto( override suspend fun initializeCrypto(
userId: String,
request: InitUserCryptoRequest, request: InitUserCryptoRequest,
): Result<InitializeCryptoResult> = ): Result<InitializeCryptoResult> =
runCatching { runCatching {
try { try {
clientCrypto.initializeUserCrypto(req = request) getClient(userId = userId)
.crypto()
.initializeUserCrypto(req = request)
InitializeCryptoResult.Success InitializeCryptoResult.Success
} catch (exception: BitwardenException) { } catch (exception: BitwardenException) {
// The only truly expected error from the SDK is an incorrect key/password. // The only truly expected error from the SDK is an incorrect key/password.
@ -43,11 +48,14 @@ class VaultSdkSourceImpl(
} }
override suspend fun initializeOrganizationCrypto( override suspend fun initializeOrganizationCrypto(
userId: String,
request: InitOrgCryptoRequest, request: InitOrgCryptoRequest,
): Result<InitializeCryptoResult> = ): Result<InitializeCryptoResult> =
runCatching { runCatching {
try { try {
clientCrypto.initializeOrgCrypto(req = request) getClient(userId = userId)
.crypto()
.initializeOrgCrypto(req = request)
InitializeCryptoResult.Success InitializeCryptoResult.Success
} catch (exception: BitwardenException) { } catch (exception: BitwardenException) {
// The only truly expected error from the SDK is for incorrect keys. // The only truly expected error from the SDK is for incorrect keys.
@ -55,53 +63,140 @@ class VaultSdkSourceImpl(
} }
} }
override suspend fun encryptCipher(cipherView: CipherView): Result<Cipher> = override suspend fun encryptCipher(
runCatching { clientVault.ciphers().encrypt(cipherView) } userId: String,
cipherView: CipherView,
): Result<Cipher> =
runCatching {
getClient(userId = userId)
.vault()
.ciphers()
.encrypt(cipherView)
}
override suspend fun decryptCipher(cipher: Cipher): Result<CipherView> = override suspend fun decryptCipher(
runCatching { clientVault.ciphers().decrypt(cipher) } userId: String,
cipher: Cipher,
): Result<CipherView> =
runCatching {
getClient(userId = userId)
.vault()
.ciphers()
.decrypt(cipher)
}
override suspend fun decryptCipherListCollection( override suspend fun decryptCipherListCollection(
userId: String,
cipherList: List<Cipher>, cipherList: List<Cipher>,
): Result<List<CipherListView>> = ): Result<List<CipherListView>> =
runCatching { clientVault.ciphers().decryptList(cipherList) }
override suspend fun decryptCipherList(cipherList: List<Cipher>): Result<List<CipherView>> =
runCatching { cipherList.map { clientVault.ciphers().decrypt(it) } }
override suspend fun decryptCollection(collection: Collection): Result<CollectionView> =
runCatching { runCatching {
clientVault.collections().decrypt(collection) getClient(userId = userId)
.vault().ciphers()
.decryptList(cipherList)
}
override suspend fun decryptCipherList(
userId: String,
cipherList: List<Cipher>,
): Result<List<CipherView>> =
runCatching {
cipherList.map {
getClient(userId = userId)
.vault()
.ciphers()
.decrypt(it)
}
}
override suspend fun decryptCollection(
userId: String,
collection: Collection,
): Result<CollectionView> =
runCatching {
getClient(userId = userId)
.vault()
.collections()
.decrypt(collection)
} }
override suspend fun decryptCollectionList( override suspend fun decryptCollectionList(
userId: String,
collectionList: List<Collection>, collectionList: List<Collection>,
): Result<List<CollectionView>> = ): Result<List<CollectionView>> =
runCatching { runCatching {
clientVault.collections().decryptList(collectionList) getClient(userId = userId)
.vault()
.collections()
.decryptList(collectionList)
} }
override suspend fun decryptSend(send: Send): Result<SendView> = override suspend fun decryptSend(
runCatching { clientVault.sends().decrypt(send) } userId: String,
send: Send,
): Result<SendView> =
runCatching {
getClient(userId = userId)
.vault()
.sends()
.decrypt(send)
}
override suspend fun decryptSendList(sendList: List<Send>): Result<List<SendView>> = override suspend fun decryptSendList(
runCatching { sendList.map { clientVault.sends().decrypt(it) } } userId: String,
sendList: List<Send>,
): Result<List<SendView>> =
runCatching {
sendList.map {
getClient(userId = userId)
.vault()
.sends()
.decrypt(it)
}
}
override suspend fun decryptFolder(folder: Folder): Result<FolderView> = override suspend fun decryptFolder(
runCatching { clientVault.folders().decrypt(folder) } userId: String,
folder: Folder,
): Result<FolderView> =
runCatching {
getClient(userId = userId)
.vault()
.folders()
.decrypt(folder)
}
override suspend fun decryptFolderList(folderList: List<Folder>): Result<List<FolderView>> = override suspend fun decryptFolderList(
runCatching { clientVault.folders().decryptList(folderList) } userId: String,
folderList: List<Folder>,
): Result<List<FolderView>> =
runCatching {
getClient(userId = userId)
.vault()
.folders()
.decryptList(folderList)
}
override suspend fun encryptPasswordHistory( override suspend fun encryptPasswordHistory(
userId: String,
passwordHistory: PasswordHistoryView, passwordHistory: PasswordHistoryView,
): Result<PasswordHistory> = runCatching { ): Result<PasswordHistory> = runCatching {
clientPasswordHistory.encrypt(passwordHistory) getClient(userId = userId)
.vault()
.passwordHistory()
.encrypt(passwordHistory)
} }
override suspend fun decryptPasswordHistoryList( override suspend fun decryptPasswordHistoryList(
userId: String,
passwordHistoryList: List<PasswordHistory>, passwordHistoryList: List<PasswordHistory>,
): Result<List<PasswordHistoryView>> = runCatching { ): Result<List<PasswordHistoryView>> = runCatching {
clientPasswordHistory.decryptList(passwordHistoryList) getClient(userId = userId)
.vault()
.passwordHistory()
.decryptList(passwordHistoryList)
} }
private fun getClient(
userId: String,
): Client = sdkClientManager.getOrCreateClient(userId = userId)
} }

View file

@ -1,6 +1,6 @@
package com.x8bit.bitwarden.data.vault.datasource.sdk.di package com.x8bit.bitwarden.data.vault.datasource.sdk.di
import com.bitwarden.sdk.Client import com.x8bit.bitwarden.data.platform.manager.SdkClientManager
import com.x8bit.bitwarden.data.vault.datasource.sdk.VaultSdkSource import com.x8bit.bitwarden.data.vault.datasource.sdk.VaultSdkSource
import com.x8bit.bitwarden.data.vault.datasource.sdk.VaultSdkSourceImpl import com.x8bit.bitwarden.data.vault.datasource.sdk.VaultSdkSourceImpl
import dagger.Module import dagger.Module
@ -19,11 +19,9 @@ object VaultSdkModule {
@Provides @Provides
@Singleton @Singleton
fun providesVaultSdkSource( fun providesVaultSdkSource(
client: Client, sdkClientManager: SdkClientManager,
): VaultSdkSource = ): VaultSdkSource =
VaultSdkSourceImpl( VaultSdkSourceImpl(
clientVault = client.vault(), sdkClientManager = sdkClientManager,
clientCrypto = client.crypto(),
clientPasswordHistory = client.vault().passwordHistory(),
) )
} }

View file

@ -320,6 +320,7 @@ class VaultRepositoryImpl(
emit( emit(
vaultSdkSource vaultSdkSource
.initializeCrypto( .initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = kdf, kdfParams = kdf,
email = email, email = email,
@ -336,6 +337,7 @@ class VaultRepositoryImpl(
result is InitializeCryptoResult.Success result is InitializeCryptoResult.Success
) { ) {
vaultSdkSource.initializeOrganizationCrypto( vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest( request = InitOrgCryptoRequest(
organizationKeys = organizationKeys, organizationKeys = organizationKeys,
), ),
@ -363,7 +365,10 @@ class VaultRepositoryImpl(
override suspend fun createCipher(cipherView: CipherView): CreateCipherResult = override suspend fun createCipher(cipherView: CipherView): CreateCipherResult =
vaultSdkSource vaultSdkSource
.encryptCipher(cipherView = cipherView) .encryptCipher(
userId = requireNotNull(activeUserId),
cipherView = cipherView,
)
.flatMap { cipher -> .flatMap { cipher ->
ciphersService ciphersService
.createCipher( .createCipher(
@ -385,7 +390,10 @@ class VaultRepositoryImpl(
cipherView: CipherView, cipherView: CipherView,
): UpdateCipherResult = ): UpdateCipherResult =
vaultSdkSource vaultSdkSource
.encryptCipher(cipherView = cipherView) .encryptCipher(
userId = requireNotNull(activeUserId),
cipherView = cipherView,
)
.flatMap { cipher -> .flatMap { cipher ->
ciphersService.updateCipher( ciphersService.updateCipher(
cipherId = cipherId, cipherId = cipherId,
@ -421,6 +429,7 @@ class VaultRepositoryImpl(
// TODO: This is temporary. Eventually this needs to be based on the presence of various // TODO: This is temporary. Eventually this needs to be based on the presence of various
// user keys but this will likely require SDK updates to support this (BIT-1190). // user keys but this will likely require SDK updates to support this (BIT-1190).
private fun setVaultToLocked(userId: String) { private fun setVaultToLocked(userId: String) {
vaultSdkSource.clearCrypto(userId = userId)
mutableVaultStateStateFlow.update { mutableVaultStateStateFlow.update {
it.copy( it.copy(
unlockedVaultUserIds = it.unlockedVaultUserIds - userId, unlockedVaultUserIds = it.unlockedVaultUserIds - userId,
@ -473,6 +482,7 @@ class VaultRepositoryImpl(
// the return type here. // the return type here.
vaultSdkSource vaultSdkSource
.initializeOrganizationCrypto( .initializeOrganizationCrypto(
userId = syncResponse.profile.id,
request = InitOrgCryptoRequest( request = InitOrgCryptoRequest(
organizationKeys = organizationKeys, organizationKeys = organizationKeys,
), ),
@ -484,10 +494,15 @@ class VaultRepositoryImpl(
): Flow<DataState<List<CipherView>>> = ): Flow<DataState<List<CipherView>>> =
vaultDiskSource vaultDiskSource
.getCiphers(userId = userId) .getCiphers(userId = userId)
.onStart { mutableCiphersStateFlow.value = DataState.Loading } .onStart {
mutableCiphersStateFlow.value = DataState.Loading
}
.map { .map {
vaultSdkSource vaultSdkSource
.decryptCipherList(cipherList = it.toEncryptedSdkCipherList()) .decryptCipherList(
userId = userId,
cipherList = it.toEncryptedSdkCipherList(),
)
.fold( .fold(
onSuccess = { ciphers -> DataState.Loaded(ciphers) }, onSuccess = { ciphers -> DataState.Loaded(ciphers) },
onFailure = { throwable -> DataState.Error(throwable) }, onFailure = { throwable -> DataState.Error(throwable) },
@ -503,7 +518,10 @@ class VaultRepositoryImpl(
.onStart { mutableFoldersStateFlow.value = DataState.Loading } .onStart { mutableFoldersStateFlow.value = DataState.Loading }
.map { .map {
vaultSdkSource vaultSdkSource
.decryptFolderList(folderList = it.toEncryptedSdkFolderList()) .decryptFolderList(
userId = userId,
folderList = it.toEncryptedSdkFolderList(),
)
.fold( .fold(
onSuccess = { folders -> DataState.Loaded(folders) }, onSuccess = { folders -> DataState.Loaded(folders) },
onFailure = { throwable -> DataState.Error(throwable) }, onFailure = { throwable -> DataState.Error(throwable) },
@ -519,7 +537,10 @@ class VaultRepositoryImpl(
.onStart { mutableCollectionsStateFlow.value = DataState.Loading } .onStart { mutableCollectionsStateFlow.value = DataState.Loading }
.map { .map {
vaultSdkSource vaultSdkSource
.decryptCollectionList(collectionList = it.toEncryptedSdkCollectionList()) .decryptCollectionList(
userId = userId,
collectionList = it.toEncryptedSdkCollectionList(),
)
.fold( .fold(
onSuccess = { collections -> DataState.Loaded(collections) }, onSuccess = { collections -> DataState.Loaded(collections) },
onFailure = { throwable -> DataState.Error(throwable) }, onFailure = { throwable -> DataState.Error(throwable) },
@ -535,7 +556,10 @@ class VaultRepositoryImpl(
.onStart { mutableSendDataStateFlow.value = DataState.Loading } .onStart { mutableSendDataStateFlow.value = DataState.Loading }
.map { .map {
vaultSdkSource vaultSdkSource
.decryptSendList(sendList = it.toEncryptedSdkSendList()) .decryptSendList(
userId = userId,
sendList = it.toEncryptedSdkSendList(),
)
.fold( .fold(
onSuccess = { sends -> DataState.Loaded(SendData(sends)) }, onSuccess = { sends -> DataState.Loaded(SendData(sends)) },
onFailure = { throwable -> DataState.Error(throwable) }, onFailure = { throwable -> DataState.Error(throwable) },

View file

@ -0,0 +1,44 @@
package com.x8bit.bitwarden.data.platform.manager
import io.mockk.mockk
import io.mockk.verify
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNotEquals
import org.junit.jupiter.api.Test
class SdkClientManagerTest {
private val sdkClientManager = SdkClientManagerImpl(
clientProvider = { mockk(relaxed = true) },
)
@Suppress("MaxLineLength")
@Test
fun `getOrCreateClient should create a new client for each userId and return a cached client for subsequent calls`() {
val userId = "userId"
val firstClient = sdkClientManager.getOrCreateClient(userId = userId)
// Additional calls for the same userId return the same value
val secondClient = sdkClientManager.getOrCreateClient(userId = userId)
assertEquals(firstClient, secondClient)
// Additional calls for different userIds should return different values
val otherUserId = "otherUserId"
val thirdClient = sdkClientManager.getOrCreateClient(userId = otherUserId)
assertNotEquals(firstClient, thirdClient)
}
@Test
fun `destroyClient should call close on the Client and remove it from the cache`() {
val userId = "userId"
val firstClient = sdkClientManager.getOrCreateClient(userId = userId)
sdkClientManager.destroyClient(userId = userId)
verify { firstClient.close() }
// New calls for the same userId should return different values
val secondClient = sdkClientManager.getOrCreateClient(userId = userId)
assertNotEquals(firstClient, secondClient)
}
}

View file

@ -105,7 +105,7 @@ class GeneratorRepositoryTest {
coEvery { generatorSdkSource.generatePassword(request) } returns coEvery { generatorSdkSource.generatePassword(request) } returns
Result.success(generatedPassword) Result.success(generatedPassword)
coEvery { vaultSdkSource.encryptPasswordHistory(any()) } returns coEvery { vaultSdkSource.encryptPasswordHistory(any(), any()) } returns
Result.success(encryptedPasswordHistory) Result.success(encryptedPasswordHistory)
coEvery { passwordHistoryDiskSource.insertPasswordHistory(any()) } just runs coEvery { passwordHistoryDiskSource.insertPasswordHistory(any()) } just runs
@ -156,7 +156,7 @@ class GeneratorRepositoryTest {
coEvery { generatorSdkSource.generatePassword(request) } returns coEvery { generatorSdkSource.generatePassword(request) } returns
Result.success(generatedPassword) Result.success(generatedPassword)
coEvery { vaultSdkSource.encryptPasswordHistory(any()) } returns coEvery { vaultSdkSource.encryptPasswordHistory(any(), any()) } returns
Result.success(encryptedPasswordHistory) Result.success(encryptedPasswordHistory)
coEvery { passwordHistoryDiskSource.insertPasswordHistory(any()) } just runs coEvery { passwordHistoryDiskSource.insertPasswordHistory(any()) } just runs
@ -220,7 +220,7 @@ class GeneratorRepositoryTest {
coEvery { generatorSdkSource.generatePassphrase(request) } returns coEvery { generatorSdkSource.generatePassphrase(request) } returns
Result.success(generatedPassphrase) Result.success(generatedPassphrase)
coEvery { vaultSdkSource.encryptPasswordHistory(any()) } returns coEvery { vaultSdkSource.encryptPasswordHistory(any(), any()) } returns
Result.success(encryptedPasswordHistory) Result.success(encryptedPasswordHistory)
coEvery { passwordHistoryDiskSource.insertPasswordHistory(any()) } just runs coEvery { passwordHistoryDiskSource.insertPasswordHistory(any()) } just runs
@ -232,7 +232,7 @@ class GeneratorRepositoryTest {
(result as GeneratedPassphraseResult.Success).generatedString, (result as GeneratedPassphraseResult.Success).generatedString,
) )
coVerify { generatorSdkSource.generatePassphrase(request) } coVerify { generatorSdkSource.generatePassphrase(request) }
coVerify { vaultSdkSource.encryptPasswordHistory(any()) } coVerify { vaultSdkSource.encryptPasswordHistory(any(), any()) }
coVerify { coVerify {
passwordHistoryDiskSource.insertPasswordHistory( passwordHistoryDiskSource.insertPasswordHistory(
encryptedPasswordHistory.toPasswordHistoryEntity(userId), encryptedPasswordHistory.toPasswordHistoryEntity(userId),
@ -406,7 +406,12 @@ class GeneratorRepositoryTest {
coEvery { authDiskSource.userState?.activeUserId } returns testUserId coEvery { authDiskSource.userState?.activeUserId } returns testUserId
coEvery { vaultSdkSource.encryptPasswordHistory(passwordHistoryView) } returns coEvery {
vaultSdkSource.encryptPasswordHistory(
userId = testUserId,
passwordHistory = passwordHistoryView,
)
} returns
Result.success(encryptedPasswordHistory) Result.success(encryptedPasswordHistory)
coEvery { coEvery {
@ -415,7 +420,12 @@ class GeneratorRepositoryTest {
repository.storePasswordHistory(passwordHistoryView) repository.storePasswordHistory(passwordHistoryView)
coVerify { vaultSdkSource.encryptPasswordHistory(passwordHistoryView) } coVerify {
vaultSdkSource.encryptPasswordHistory(
userId = testUserId,
passwordHistory = passwordHistoryView,
)
}
coVerify { passwordHistoryDiskSource.insertPasswordHistory(expectedPasswordHistoryEntity) } coVerify { passwordHistoryDiskSource.insertPasswordHistory(expectedPasswordHistoryEntity) }
} }
@ -451,7 +461,7 @@ class GeneratorRepositoryTest {
} returns flowOf(encryptedPasswordHistoryEntities) } returns flowOf(encryptedPasswordHistoryEntities)
coEvery { coEvery {
vaultSdkSource.decryptPasswordHistoryList(any()) vaultSdkSource.decryptPasswordHistoryList(any(), any())
} returns Result.success(decryptedPasswordHistoryList) } returns Result.success(decryptedPasswordHistoryList)
val historyFlow = repository.passwordHistoryStateFlow val historyFlow = repository.passwordHistoryStateFlow
@ -467,7 +477,7 @@ class GeneratorRepositoryTest {
passwordHistoryDiskSource.getPasswordHistoriesForUser(USER_STATE.activeUserId) passwordHistoryDiskSource.getPasswordHistoriesForUser(USER_STATE.activeUserId)
} }
coVerify { vaultSdkSource.decryptPasswordHistoryList(any()) } coVerify { vaultSdkSource.decryptPasswordHistoryList(any(), any()) }
} }
@Test @Test

View file

@ -14,32 +14,56 @@ import com.bitwarden.core.PasswordHistoryView
import com.bitwarden.core.Send import com.bitwarden.core.Send
import com.bitwarden.core.SendView import com.bitwarden.core.SendView
import com.bitwarden.sdk.BitwardenException import com.bitwarden.sdk.BitwardenException
import com.bitwarden.sdk.Client
import com.bitwarden.sdk.ClientCrypto import com.bitwarden.sdk.ClientCrypto
import com.bitwarden.sdk.ClientPasswordHistory import com.bitwarden.sdk.ClientPasswordHistory
import com.bitwarden.sdk.ClientVault import com.bitwarden.sdk.ClientVault
import com.x8bit.bitwarden.data.platform.manager.SdkClientManager
import com.x8bit.bitwarden.data.platform.util.asFailure import com.x8bit.bitwarden.data.platform.util.asFailure
import com.x8bit.bitwarden.data.platform.util.asSuccess import com.x8bit.bitwarden.data.platform.util.asSuccess
import com.x8bit.bitwarden.data.vault.datasource.sdk.model.InitializeCryptoResult import com.x8bit.bitwarden.data.vault.datasource.sdk.model.InitializeCryptoResult
import io.mockk.coEvery import io.mockk.coEvery
import io.mockk.coVerify import io.mockk.coVerify
import io.mockk.every
import io.mockk.just
import io.mockk.mockk import io.mockk.mockk
import io.mockk.runs
import io.mockk.verify
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
class VaultSdkSourceTest { class VaultSdkSourceTest {
private val clientVault = mockk<ClientVault>()
private val clientCrypto = mockk<ClientCrypto>() private val clientCrypto = mockk<ClientCrypto>()
private val clientPasswordHistory = mockk<ClientPasswordHistory>() private val clientPasswordHistory = mockk<ClientPasswordHistory>()
private val clientVault = mockk<ClientVault>() {
every { passwordHistory() } returns clientPasswordHistory
}
private val client = mockk<Client>() {
every { vault() } returns clientVault
every { crypto() } returns clientCrypto
}
private val sdkClientManager = mockk<SdkClientManager> {
every { getOrCreateClient(any()) } returns client
every { destroyClient(any()) } just runs
}
private val vaultSdkSource: VaultSdkSource = VaultSdkSourceImpl( private val vaultSdkSource: VaultSdkSource = VaultSdkSourceImpl(
clientVault = clientVault, sdkClientManager = sdkClientManager,
clientCrypto = clientCrypto,
clientPasswordHistory = clientPasswordHistory,
) )
@Test
fun `clearCrypto should destroy the associated client via the SDK Manager`() {
val userId = "userId"
vaultSdkSource.clearCrypto(userId = userId)
verify { sdkClientManager.destroyClient(userId = userId) }
}
@Test @Test
fun `initializeUserCrypto with sdk success should return InitializeCryptoResult Success`() = fun `initializeUserCrypto with sdk success should return InitializeCryptoResult Success`() =
runBlocking { runBlocking {
val userId = "userId"
val mockInitCryptoRequest = mockk<InitUserCryptoRequest>() val mockInitCryptoRequest = mockk<InitUserCryptoRequest>()
coEvery { coEvery {
clientCrypto.initializeUserCrypto( clientCrypto.initializeUserCrypto(
@ -47,6 +71,7 @@ class VaultSdkSourceTest {
) )
} returns Unit } returns Unit
val result = vaultSdkSource.initializeCrypto( val result = vaultSdkSource.initializeCrypto(
userId = userId,
request = mockInitCryptoRequest, request = mockInitCryptoRequest,
) )
assertEquals( assertEquals(
@ -58,10 +83,12 @@ class VaultSdkSourceTest {
req = mockInitCryptoRequest, req = mockInitCryptoRequest,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
fun `initializeUserCrypto with sdk failure should return failure`() = runBlocking { fun `initializeUserCrypto with sdk failure should return failure`() = runBlocking {
val userId = "userId"
val mockInitCryptoRequest = mockk<InitUserCryptoRequest>() val mockInitCryptoRequest = mockk<InitUserCryptoRequest>()
val expectedException = IllegalStateException("mock") val expectedException = IllegalStateException("mock")
coEvery { coEvery {
@ -70,6 +97,7 @@ class VaultSdkSourceTest {
) )
} throws expectedException } throws expectedException
val result = vaultSdkSource.initializeCrypto( val result = vaultSdkSource.initializeCrypto(
userId = userId,
request = mockInitCryptoRequest, request = mockInitCryptoRequest,
) )
assertEquals( assertEquals(
@ -81,11 +109,13 @@ class VaultSdkSourceTest {
req = mockInitCryptoRequest, req = mockInitCryptoRequest,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
fun `initializeUserCrypto with BitwardenException failure should return AuthenticationError`() = fun `initializeUserCrypto with BitwardenException failure should return AuthenticationError`() =
runBlocking { runBlocking {
val userId = "userId"
val mockInitCryptoRequest = mockk<InitUserCryptoRequest>() val mockInitCryptoRequest = mockk<InitUserCryptoRequest>()
val expectedException = BitwardenException.E(message = "") val expectedException = BitwardenException.E(message = "")
coEvery { coEvery {
@ -94,6 +124,7 @@ class VaultSdkSourceTest {
) )
} throws expectedException } throws expectedException
val result = vaultSdkSource.initializeCrypto( val result = vaultSdkSource.initializeCrypto(
userId = userId,
request = mockInitCryptoRequest, request = mockInitCryptoRequest,
) )
assertEquals( assertEquals(
@ -105,11 +136,13 @@ class VaultSdkSourceTest {
req = mockInitCryptoRequest, req = mockInitCryptoRequest,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
fun `initializeOrgCrypto with sdk success should return InitializeCryptoResult Success`() = fun `initializeOrgCrypto with sdk success should return InitializeCryptoResult Success`() =
runBlocking { runBlocking {
val userId = "userId"
val mockInitCryptoRequest = mockk<InitOrgCryptoRequest>() val mockInitCryptoRequest = mockk<InitOrgCryptoRequest>()
coEvery { coEvery {
clientCrypto.initializeOrgCrypto( clientCrypto.initializeOrgCrypto(
@ -117,6 +150,7 @@ class VaultSdkSourceTest {
) )
} returns Unit } returns Unit
val result = vaultSdkSource.initializeOrganizationCrypto( val result = vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = mockInitCryptoRequest, request = mockInitCryptoRequest,
) )
assertEquals( assertEquals(
@ -128,10 +162,12 @@ class VaultSdkSourceTest {
req = mockInitCryptoRequest, req = mockInitCryptoRequest,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
fun `initializeOrgCrypto with sdk failure should return failure`() = runBlocking { fun `initializeOrgCrypto with sdk failure should return failure`() = runBlocking {
val userId = "userId"
val mockInitCryptoRequest = mockk<InitOrgCryptoRequest>() val mockInitCryptoRequest = mockk<InitOrgCryptoRequest>()
val expectedException = IllegalStateException("mock") val expectedException = IllegalStateException("mock")
coEvery { coEvery {
@ -140,6 +176,7 @@ class VaultSdkSourceTest {
) )
} throws expectedException } throws expectedException
val result = vaultSdkSource.initializeOrganizationCrypto( val result = vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = mockInitCryptoRequest, request = mockInitCryptoRequest,
) )
assertEquals( assertEquals(
@ -151,11 +188,13 @@ class VaultSdkSourceTest {
req = mockInitCryptoRequest, req = mockInitCryptoRequest,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
fun `initializeOrgCrypto with BitwardenException failure should return AuthenticationError`() = fun `initializeOrgCrypto with BitwardenException failure should return AuthenticationError`() =
runBlocking { runBlocking {
val userId = "userId"
val mockInitCryptoRequest = mockk<InitOrgCryptoRequest>() val mockInitCryptoRequest = mockk<InitOrgCryptoRequest>()
val expectedException = BitwardenException.E(message = "") val expectedException = BitwardenException.E(message = "")
coEvery { coEvery {
@ -164,6 +203,7 @@ class VaultSdkSourceTest {
) )
} throws expectedException } throws expectedException
val result = vaultSdkSource.initializeOrganizationCrypto( val result = vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = mockInitCryptoRequest, request = mockInitCryptoRequest,
) )
assertEquals( assertEquals(
@ -175,10 +215,12 @@ class VaultSdkSourceTest {
req = mockInitCryptoRequest, req = mockInitCryptoRequest,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
fun `decryptCipher should call SDK and return a Result with correct data`() = runBlocking { fun `decryptCipher should call SDK and return a Result with correct data`() = runBlocking {
val userId = "userId"
val mockCipher = mockk<CipherView>() val mockCipher = mockk<CipherView>()
val expectedResult = mockk<Cipher>() val expectedResult = mockk<Cipher>()
coEvery { coEvery {
@ -187,6 +229,7 @@ class VaultSdkSourceTest {
) )
} returns expectedResult } returns expectedResult
val result = vaultSdkSource.encryptCipher( val result = vaultSdkSource.encryptCipher(
userId = userId,
cipherView = mockCipher, cipherView = mockCipher,
) )
assertEquals( assertEquals(
@ -198,10 +241,12 @@ class VaultSdkSourceTest {
cipherView = mockCipher, cipherView = mockCipher,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
fun `Cipher decrypt should call SDK and return a Result with correct data`() = runBlocking { fun `Cipher decrypt should call SDK and return a Result with correct data`() = runBlocking {
val userId = "userId"
val mockCipher = mockk<Cipher>() val mockCipher = mockk<Cipher>()
val expectedResult = mockk<CipherView>() val expectedResult = mockk<CipherView>()
coEvery { coEvery {
@ -210,6 +255,7 @@ class VaultSdkSourceTest {
) )
} returns expectedResult } returns expectedResult
val result = vaultSdkSource.decryptCipher( val result = vaultSdkSource.decryptCipher(
userId = userId,
cipher = mockCipher, cipher = mockCipher,
) )
assertEquals( assertEquals(
@ -221,11 +267,13 @@ class VaultSdkSourceTest {
cipher = mockCipher, cipher = mockCipher,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
fun `Cipher decryptListCollection should call SDK and return a Result with correct data`() = fun `Cipher decryptListCollection should call SDK and return a Result with correct data`() =
runBlocking { runBlocking {
val userId = "userId"
val mockCiphers = mockk<List<Cipher>>() val mockCiphers = mockk<List<Cipher>>()
val expectedResult = mockk<List<CipherListView>>() val expectedResult = mockk<List<CipherListView>>()
coEvery { coEvery {
@ -234,6 +282,7 @@ class VaultSdkSourceTest {
) )
} returns expectedResult } returns expectedResult
val result = vaultSdkSource.decryptCipherListCollection( val result = vaultSdkSource.decryptCipherListCollection(
userId = userId,
cipherList = mockCiphers, cipherList = mockCiphers,
) )
assertEquals( assertEquals(
@ -245,10 +294,12 @@ class VaultSdkSourceTest {
ciphers = mockCiphers, ciphers = mockCiphers,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
fun `Cipher decryptList should call SDK and return a Result with correct data`() = runBlocking { fun `Cipher decryptList should call SDK and return a Result with correct data`() = runBlocking {
val userId = "userId"
val mockCiphers = mockk<Cipher>() val mockCiphers = mockk<Cipher>()
val expectedResult = mockk<CipherView>() val expectedResult = mockk<CipherView>()
coEvery { coEvery {
@ -257,6 +308,7 @@ class VaultSdkSourceTest {
) )
} returns expectedResult } returns expectedResult
val result = vaultSdkSource.decryptCipherList( val result = vaultSdkSource.decryptCipherList(
userId = userId,
cipherList = listOf(mockCiphers), cipherList = listOf(mockCiphers),
) )
assertEquals( assertEquals(
@ -268,11 +320,13 @@ class VaultSdkSourceTest {
cipher = mockCiphers, cipher = mockCiphers,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
fun `decryptCollection should call SDK and return correct data wrapped in a Result`() = fun `decryptCollection should call SDK and return correct data wrapped in a Result`() =
runBlocking { runBlocking {
val userId = "userId"
val mockCollection = mockk<Collection>() val mockCollection = mockk<Collection>()
val expectedResult = mockk<CollectionView>() val expectedResult = mockk<CollectionView>()
coEvery { coEvery {
@ -281,6 +335,7 @@ class VaultSdkSourceTest {
) )
} returns expectedResult } returns expectedResult
val result = vaultSdkSource.decryptCollection( val result = vaultSdkSource.decryptCollection(
userId = userId,
collection = mockCollection, collection = mockCollection,
) )
assertEquals( assertEquals(
@ -291,11 +346,13 @@ class VaultSdkSourceTest {
collection = mockCollection, collection = mockCollection,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
fun `decryptCollectionList should call SDK and return correct data wrapped in a Result`() = fun `decryptCollectionList should call SDK and return correct data wrapped in a Result`() =
runBlocking { runBlocking {
val userId = "userId"
val mockCollectionsList = mockk<List<Collection>>() val mockCollectionsList = mockk<List<Collection>>()
val expectedResult = mockk<List<CollectionView>>() val expectedResult = mockk<List<CollectionView>>()
coEvery { coEvery {
@ -304,6 +361,7 @@ class VaultSdkSourceTest {
) )
} returns expectedResult } returns expectedResult
val result = vaultSdkSource.decryptCollectionList( val result = vaultSdkSource.decryptCollectionList(
userId = userId,
collectionList = mockCollectionsList, collectionList = mockCollectionsList,
) )
assertEquals( assertEquals(
@ -315,11 +373,13 @@ class VaultSdkSourceTest {
collections = mockCollectionsList, collections = mockCollectionsList,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
fun `decryptSendList should call SDK and return correct data wrapped in a Result`() = fun `decryptSendList should call SDK and return correct data wrapped in a Result`() =
runBlocking { runBlocking {
val userId = "userId"
val mockSend = mockk<Send>() val mockSend = mockk<Send>()
val expectedResult = mockk<SendView>() val expectedResult = mockk<SendView>()
coEvery { coEvery {
@ -328,6 +388,7 @@ class VaultSdkSourceTest {
) )
} returns expectedResult } returns expectedResult
val result = vaultSdkSource.decryptSendList( val result = vaultSdkSource.decryptSendList(
userId = userId,
sendList = listOf(mockSend), sendList = listOf(mockSend),
) )
assertEquals( assertEquals(
@ -339,11 +400,13 @@ class VaultSdkSourceTest {
send = mockSend, send = mockSend,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
fun `decryptSend should call SDK and return correct data wrapped in a Result`() = fun `decryptSend should call SDK and return correct data wrapped in a Result`() =
runBlocking { runBlocking {
val userId = "userId"
val mockSend = mockk<Send>() val mockSend = mockk<Send>()
val expectedResult = mockk<SendView>() val expectedResult = mockk<SendView>()
coEvery { coEvery {
@ -352,6 +415,7 @@ class VaultSdkSourceTest {
) )
} returns expectedResult } returns expectedResult
val result = vaultSdkSource.decryptSend( val result = vaultSdkSource.decryptSend(
userId = userId,
send = mockSend, send = mockSend,
) )
assertEquals( assertEquals(
@ -362,10 +426,12 @@ class VaultSdkSourceTest {
send = mockSend, send = mockSend,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
fun `Folder decrypt should call SDK and return a Result with correct data`() = runBlocking { fun `Folder decrypt should call SDK and return a Result with correct data`() = runBlocking {
val userId = "userId"
val mockFolder = mockk<Folder>() val mockFolder = mockk<Folder>()
val expectedResult = mockk<FolderView>() val expectedResult = mockk<FolderView>()
coEvery { coEvery {
@ -374,6 +440,7 @@ class VaultSdkSourceTest {
) )
} returns expectedResult } returns expectedResult
val result = vaultSdkSource.decryptFolder( val result = vaultSdkSource.decryptFolder(
userId = userId,
folder = mockFolder, folder = mockFolder,
) )
assertEquals( assertEquals(
@ -385,10 +452,12 @@ class VaultSdkSourceTest {
folder = mockFolder, folder = mockFolder,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
fun `Folder decryptList should call SDK and return a Result with correct data`() = runBlocking { fun `Folder decryptList should call SDK and return a Result with correct data`() = runBlocking {
val userId = "userId"
val mockFolders = mockk<List<Folder>>() val mockFolders = mockk<List<Folder>>()
val expectedResult = mockk<List<FolderView>>() val expectedResult = mockk<List<FolderView>>()
coEvery { coEvery {
@ -397,6 +466,7 @@ class VaultSdkSourceTest {
) )
} returns expectedResult } returns expectedResult
val result = vaultSdkSource.decryptFolderList( val result = vaultSdkSource.decryptFolderList(
userId = userId,
folderList = mockFolders, folderList = mockFolders,
) )
assertEquals( assertEquals(
@ -408,11 +478,13 @@ class VaultSdkSourceTest {
folders = mockFolders, folders = mockFolders,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
fun `encryptPasswordHistory should call SDK and return a Result with correct data`() = fun `encryptPasswordHistory should call SDK and return a Result with correct data`() =
runBlocking { runBlocking {
val userId = "userId"
val mockPasswordHistoryView = mockk<PasswordHistoryView>() val mockPasswordHistoryView = mockk<PasswordHistoryView>()
val expectedResult = mockk<PasswordHistory>() val expectedResult = mockk<PasswordHistory>()
coEvery { coEvery {
@ -422,6 +494,7 @@ class VaultSdkSourceTest {
} returns expectedResult } returns expectedResult
val result = vaultSdkSource.encryptPasswordHistory( val result = vaultSdkSource.encryptPasswordHistory(
userId = userId,
passwordHistory = mockPasswordHistoryView, passwordHistory = mockPasswordHistoryView,
) )
@ -431,11 +504,13 @@ class VaultSdkSourceTest {
passwordHistory = mockPasswordHistoryView, passwordHistory = mockPasswordHistoryView,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
fun `decryptPasswordHistoryList should call SDK and return a Result with correct data`() = fun `decryptPasswordHistoryList should call SDK and return a Result with correct data`() =
runBlocking { runBlocking {
val userId = "userId"
val mockPasswordHistoryList = mockk<List<PasswordHistory>>() val mockPasswordHistoryList = mockk<List<PasswordHistory>>()
val expectedResult = mockk<List<PasswordHistoryView>>() val expectedResult = mockk<List<PasswordHistoryView>>()
coEvery { coEvery {
@ -445,6 +520,7 @@ class VaultSdkSourceTest {
} returns expectedResult } returns expectedResult
val result = vaultSdkSource.decryptPasswordHistoryList( val result = vaultSdkSource.decryptPasswordHistoryList(
userId = userId,
passwordHistoryList = mockPasswordHistoryList, passwordHistoryList = mockPasswordHistoryList,
) )
@ -454,5 +530,6 @@ class VaultSdkSourceTest {
list = mockPasswordHistoryList, list = mockPasswordHistoryList,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
} }

View file

@ -59,6 +59,7 @@ import io.mockk.every
import io.mockk.just import io.mockk.just
import io.mockk.mockk import io.mockk.mockk
import io.mockk.runs import io.mockk.runs
import io.mockk.verify
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.cancel import kotlinx.coroutines.cancel
@ -77,7 +78,9 @@ class VaultRepositoryTest {
private val syncService: SyncService = mockk() private val syncService: SyncService = mockk()
private val ciphersService: CiphersService = mockk() private val ciphersService: CiphersService = mockk()
private val vaultDiskSource: VaultDiskSource = mockk() private val vaultDiskSource: VaultDiskSource = mockk()
private val vaultSdkSource: VaultSdkSource = mockk() private val vaultSdkSource: VaultSdkSource = mockk {
every { clearCrypto(userId = any()) } just runs
}
private val vaultRepository = VaultRepositoryImpl( private val vaultRepository = VaultRepositoryImpl(
syncService = syncService, syncService = syncService,
ciphersService = ciphersService, ciphersService = ciphersService,
@ -91,6 +94,7 @@ class VaultRepositoryTest {
fun `ciphersStateFlow should emit decrypted list of ciphers when decryptCipherList succeeds`() = fun `ciphersStateFlow should emit decrypted list of ciphers when decryptCipherList succeeds`() =
runTest { runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val mockCipherList = listOf(createMockCipher(number = 1)) val mockCipherList = listOf(createMockCipher(number = 1))
val mockEncryptedCipherList = mockCipherList.toEncryptedSdkCipherList() val mockEncryptedCipherList = mockCipherList.toEncryptedSdkCipherList()
val mockCipherViewList = listOf(createMockCipherView(number = 1)) val mockCipherViewList = listOf(createMockCipherView(number = 1))
@ -100,7 +104,10 @@ class VaultRepositoryTest {
vaultDiskSource.getCiphers(userId = MOCK_USER_STATE.activeUserId) vaultDiskSource.getCiphers(userId = MOCK_USER_STATE.activeUserId)
} returns mutableCiphersStateFlow } returns mutableCiphersStateFlow
coEvery { coEvery {
vaultSdkSource.decryptCipherList(mockEncryptedCipherList) vaultSdkSource.decryptCipherList(
userId = userId,
cipherList = mockEncryptedCipherList,
)
} returns mockCipherViewList.asSuccess() } returns mockCipherViewList.asSuccess()
vaultRepository vaultRepository
@ -115,6 +122,7 @@ class VaultRepositoryTest {
@Test @Test
fun `ciphersStateFlow should emit an error when decryptCipherList fails`() = runTest { fun `ciphersStateFlow should emit an error when decryptCipherList fails`() = runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val throwable = Throwable("Fail") val throwable = Throwable("Fail")
val mockCipherList = listOf(createMockCipher(number = 1)) val mockCipherList = listOf(createMockCipher(number = 1))
val mockEncryptedCipherList = mockCipherList.toEncryptedSdkCipherList() val mockEncryptedCipherList = mockCipherList.toEncryptedSdkCipherList()
@ -124,7 +132,10 @@ class VaultRepositoryTest {
vaultDiskSource.getCiphers(userId = MOCK_USER_STATE.activeUserId) vaultDiskSource.getCiphers(userId = MOCK_USER_STATE.activeUserId)
} returns mutableCiphersStateFlow } returns mutableCiphersStateFlow
coEvery { coEvery {
vaultSdkSource.decryptCipherList(mockEncryptedCipherList) vaultSdkSource.decryptCipherList(
userId = userId,
cipherList = mockEncryptedCipherList,
)
} returns throwable.asFailure() } returns throwable.asFailure()
vaultRepository vaultRepository
@ -141,6 +152,7 @@ class VaultRepositoryTest {
fun `collectionsStateFlow should emit decrypted list of collections when decryptCollectionList succeeds`() = fun `collectionsStateFlow should emit decrypted list of collections when decryptCollectionList succeeds`() =
runTest { runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val mockCollectionList = listOf(createMockCollection(number = 1)) val mockCollectionList = listOf(createMockCollection(number = 1))
val mockEncryptedCollectionList = mockCollectionList.toEncryptedSdkCollectionList() val mockEncryptedCollectionList = mockCollectionList.toEncryptedSdkCollectionList()
val mockCollectionViewList = listOf(createMockCollectionView(number = 1)) val mockCollectionViewList = listOf(createMockCollectionView(number = 1))
@ -150,7 +162,10 @@ class VaultRepositoryTest {
vaultDiskSource.getCollections(userId = MOCK_USER_STATE.activeUserId) vaultDiskSource.getCollections(userId = MOCK_USER_STATE.activeUserId)
} returns mutableCollectionsStateFlow } returns mutableCollectionsStateFlow
coEvery { coEvery {
vaultSdkSource.decryptCollectionList(mockEncryptedCollectionList) vaultSdkSource.decryptCollectionList(
userId = userId,
collectionList = mockEncryptedCollectionList,
)
} returns mockCollectionViewList.asSuccess() } returns mockCollectionViewList.asSuccess()
vaultRepository vaultRepository
@ -165,6 +180,7 @@ class VaultRepositoryTest {
@Test @Test
fun `collectionsStateFlow should emit an error when decryptCollectionList fails`() = runTest { fun `collectionsStateFlow should emit an error when decryptCollectionList fails`() = runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val throwable = Throwable("Fail") val throwable = Throwable("Fail")
val mockCollectionList = listOf(createMockCollection(number = 1)) val mockCollectionList = listOf(createMockCollection(number = 1))
val mockEncryptedCollectionList = mockCollectionList.toEncryptedSdkCollectionList() val mockEncryptedCollectionList = mockCollectionList.toEncryptedSdkCollectionList()
@ -174,7 +190,10 @@ class VaultRepositoryTest {
vaultDiskSource.getCollections(userId = MOCK_USER_STATE.activeUserId) vaultDiskSource.getCollections(userId = MOCK_USER_STATE.activeUserId)
} returns mutableCollectionStateFlow } returns mutableCollectionStateFlow
coEvery { coEvery {
vaultSdkSource.decryptCollectionList(mockEncryptedCollectionList) vaultSdkSource.decryptCollectionList(
userId = userId,
collectionList = mockEncryptedCollectionList,
)
} returns throwable.asFailure() } returns throwable.asFailure()
vaultRepository vaultRepository
@ -191,6 +210,7 @@ class VaultRepositoryTest {
fun `foldersStateFlow should emit decrypted list of folders when decryptFolderList succeeds`() = fun `foldersStateFlow should emit decrypted list of folders when decryptFolderList succeeds`() =
runTest { runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val mockFolderList = listOf(createMockFolder(number = 1)) val mockFolderList = listOf(createMockFolder(number = 1))
val mockEncryptedFolderList = mockFolderList.toEncryptedSdkFolderList() val mockEncryptedFolderList = mockFolderList.toEncryptedSdkFolderList()
val mockFolderViewList = listOf(createMockFolderView(number = 1)) val mockFolderViewList = listOf(createMockFolderView(number = 1))
@ -200,7 +220,10 @@ class VaultRepositoryTest {
vaultDiskSource.getFolders(userId = MOCK_USER_STATE.activeUserId) vaultDiskSource.getFolders(userId = MOCK_USER_STATE.activeUserId)
} returns mutableFoldersStateFlow } returns mutableFoldersStateFlow
coEvery { coEvery {
vaultSdkSource.decryptFolderList(mockEncryptedFolderList) vaultSdkSource.decryptFolderList(
userId = userId,
folderList = mockEncryptedFolderList,
)
} returns mockFolderViewList.asSuccess() } returns mockFolderViewList.asSuccess()
vaultRepository vaultRepository
@ -215,6 +238,7 @@ class VaultRepositoryTest {
@Test @Test
fun `foldersStateFlow should emit an error when decryptFolderList fails`() = runTest { fun `foldersStateFlow should emit an error when decryptFolderList fails`() = runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val throwable = Throwable("Fail") val throwable = Throwable("Fail")
val mockFolderList = listOf(createMockFolder(number = 1)) val mockFolderList = listOf(createMockFolder(number = 1))
val mockEncryptedFolderList = mockFolderList.toEncryptedSdkFolderList() val mockEncryptedFolderList = mockFolderList.toEncryptedSdkFolderList()
@ -224,7 +248,10 @@ class VaultRepositoryTest {
vaultDiskSource.getFolders(userId = MOCK_USER_STATE.activeUserId) vaultDiskSource.getFolders(userId = MOCK_USER_STATE.activeUserId)
} returns mutableFoldersStateFlow } returns mutableFoldersStateFlow
coEvery { coEvery {
vaultSdkSource.decryptFolderList(mockEncryptedFolderList) vaultSdkSource.decryptFolderList(
userId = userId,
folderList = mockEncryptedFolderList,
)
} returns throwable.asFailure() } returns throwable.asFailure()
vaultRepository vaultRepository
@ -240,6 +267,7 @@ class VaultRepositoryTest {
fun `sendDataStateFlow should emit decrypted list of sends when decryptSendsList succeeds`() = fun `sendDataStateFlow should emit decrypted list of sends when decryptSendsList succeeds`() =
runTest { runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val mockSendList = listOf(createMockSend(number = 1)) val mockSendList = listOf(createMockSend(number = 1))
val mockEncryptedSendList = mockSendList.toEncryptedSdkSendList() val mockEncryptedSendList = mockSendList.toEncryptedSdkSendList()
val mockSendViewList = listOf(createMockSendView(number = 1)) val mockSendViewList = listOf(createMockSendView(number = 1))
@ -249,7 +277,10 @@ class VaultRepositoryTest {
vaultDiskSource.getSends(userId = MOCK_USER_STATE.activeUserId) vaultDiskSource.getSends(userId = MOCK_USER_STATE.activeUserId)
} returns mutableSendsStateFlow } returns mutableSendsStateFlow
coEvery { coEvery {
vaultSdkSource.decryptSendList(mockEncryptedSendList) vaultSdkSource.decryptSendList(
userId = userId,
sendList = mockEncryptedSendList,
)
} returns mockSendViewList.asSuccess() } returns mockSendViewList.asSuccess()
vaultRepository vaultRepository
@ -264,6 +295,7 @@ class VaultRepositoryTest {
@Test @Test
fun `sendDataStateFlow should emit an error when decryptSendsList fails`() = runTest { fun `sendDataStateFlow should emit an error when decryptSendsList fails`() = runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val throwable = Throwable("Fail") val throwable = Throwable("Fail")
val mockSendList = listOf(createMockSend(number = 1)) val mockSendList = listOf(createMockSend(number = 1))
val mockEncryptedSendList = mockSendList.toEncryptedSdkSendList() val mockEncryptedSendList = mockSendList.toEncryptedSdkSendList()
@ -273,7 +305,10 @@ class VaultRepositoryTest {
vaultDiskSource.getSends(userId = MOCK_USER_STATE.activeUserId) vaultDiskSource.getSends(userId = MOCK_USER_STATE.activeUserId)
} returns mutableSendsStateFlow } returns mutableSendsStateFlow
coEvery { coEvery {
vaultSdkSource.decryptSendList(mockEncryptedSendList) vaultSdkSource.decryptSendList(
userId = userId,
sendList = mockEncryptedSendList,
)
} returns throwable.asFailure() } returns throwable.asFailure()
vaultRepository vaultRepository
@ -302,10 +337,12 @@ class VaultRepositoryTest {
fun `sync with syncService Success should unlock the vault for orgs if necessary and update AuthDiskSource and VaultDiskSource`() = fun `sync with syncService Success should unlock the vault for orgs if necessary and update AuthDiskSource and VaultDiskSource`() =
runTest { runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val mockSyncResponse = createMockSyncResponse(number = 1) val mockSyncResponse = createMockSyncResponse(number = 1)
coEvery { syncService.sync() } returns mockSyncResponse.asSuccess() coEvery { syncService.sync() } returns mockSyncResponse.asSuccess()
coEvery { coEvery {
vaultSdkSource.initializeOrganizationCrypto( vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest( request = InitOrgCryptoRequest(
organizationKeys = createMockOrganizationKeys(1), organizationKeys = createMockOrganizationKeys(1),
), ),
@ -356,6 +393,7 @@ class VaultRepositoryTest {
vault = mockSyncResponse, vault = mockSyncResponse,
) )
vaultSdkSource.initializeOrganizationCrypto( vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest( request = InitOrgCryptoRequest(
organizationKeys = createMockOrganizationKeys(1), organizationKeys = createMockOrganizationKeys(1),
), ),
@ -366,6 +404,7 @@ class VaultRepositoryTest {
@Test @Test
fun `sync with syncService Failure should update DataStateFlow with an Error`() = runTest { fun `sync with syncService Failure should update DataStateFlow with an Error`() = runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val mockException = IllegalStateException("sad") val mockException = IllegalStateException("sad")
coEvery { syncService.sync() } returns mockException.asFailure() coEvery { syncService.sync() } returns mockException.asFailure()
@ -392,6 +431,7 @@ class VaultRepositoryTest {
@Test @Test
fun `sync with syncService Failure should update vaultDataStateFlow with an Error`() = runTest { fun `sync with syncService Failure should update vaultDataStateFlow with an Error`() = runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val mockException = IllegalStateException("sad") val mockException = IllegalStateException("sad")
coEvery { syncService.sync() } returns mockException.asFailure() coEvery { syncService.sync() } returns mockException.asFailure()
setupVaultDiskSourceFlows() setupVaultDiskSourceFlows()
@ -408,6 +448,7 @@ class VaultRepositoryTest {
@Test @Test
fun `sync with NoNetwork should update DataStateFlows to NoNetwork`() = runTest { fun `sync with NoNetwork should update DataStateFlows to NoNetwork`() = runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
coEvery { syncService.sync() } returns UnknownHostException().asFailure() coEvery { syncService.sync() } returns UnknownHostException().asFailure()
vaultRepository.sync() vaultRepository.sync()
@ -433,6 +474,7 @@ class VaultRepositoryTest {
@Test @Test
fun `sync with NoNetwork should update vaultDataStateFlow to NoNetwork`() = runTest { fun `sync with NoNetwork should update vaultDataStateFlow to NoNetwork`() = runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
coEvery { syncService.sync() } returns UnknownHostException().asFailure() coEvery { syncService.sync() } returns UnknownHostException().asFailure()
setupVaultDiskSourceFlows() setupVaultDiskSourceFlows()
@ -450,11 +492,15 @@ class VaultRepositoryTest {
fun `sync with NoNetwork data should update sendDataStateFlow to Pending and NoNetwork with data`() = fun `sync with NoNetwork data should update sendDataStateFlow to Pending and NoNetwork with data`() =
runTest { runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
coEvery { syncService.sync() } returns UnknownHostException().asFailure() coEvery { syncService.sync() } returns UnknownHostException().asFailure()
val sendsFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Send>>() val sendsFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Send>>()
setupVaultDiskSourceFlows(sendsFlow = sendsFlow) setupVaultDiskSourceFlows(sendsFlow = sendsFlow)
coEvery { coEvery {
vaultSdkSource.decryptSendList(listOf(createMockSdkSend(1))) vaultSdkSource.decryptSendList(
userId = userId,
sendList = listOf(createMockSdkSend(1)),
)
} returns listOf(createMockSendView(1)).asSuccess() } returns listOf(createMockSendView(1)).asSuccess()
vaultRepository vaultRepository
@ -501,6 +547,7 @@ class VaultRepositoryTest {
), ),
vaultRepository.vaultStateFlow.value, vaultRepository.vaultStateFlow.value,
) )
verify { vaultSdkSource.clearCrypto(userId = userId) }
} }
@Test @Test
@ -524,16 +571,19 @@ class VaultRepositoryTest {
), ),
vaultRepository.vaultStateFlow.value, vaultRepository.vaultStateFlow.value,
) )
verify { vaultSdkSource.clearCrypto(userId = userId) }
} }
@Suppress("MaxLineLength") @Suppress("MaxLineLength")
@Test @Test
fun `unlockVaultAndSyncForCurrentUser with unlockVault Success should sync and return Success`() = fun `unlockVaultAndSyncForCurrentUser with unlockVault Success should sync and return Success`() =
runTest { runTest {
val userId = "mockId-1"
val mockSyncResponse = createMockSyncResponse(number = 1) val mockSyncResponse = createMockSyncResponse(number = 1)
coEvery { syncService.sync() } returns mockSyncResponse.asSuccess() coEvery { syncService.sync() } returns mockSyncResponse.asSuccess()
coEvery { coEvery {
vaultSdkSource.initializeOrganizationCrypto( vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest( request = InitOrgCryptoRequest(
organizationKeys = createMockOrganizationKeys(1), organizationKeys = createMockOrganizationKeys(1),
), ),
@ -546,7 +596,10 @@ class VaultRepositoryTest {
) )
} just runs } just runs
coEvery { coEvery {
vaultSdkSource.decryptSendList(listOf(createMockSdkSend(number = 1))) vaultSdkSource.decryptSendList(
userId = userId,
sendList = listOf(createMockSdkSend(number = 1)),
)
} returns listOf(createMockSendView(number = 1)).asSuccess() } returns listOf(createMockSendView(number = 1)).asSuccess()
fakeAuthDiskSource.storePrivateKey( fakeAuthDiskSource.storePrivateKey(
userId = "mockId-1", userId = "mockId-1",
@ -563,6 +616,7 @@ class VaultRepositoryTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
coEvery { coEvery {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()), kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()),
email = "email", email = "email",
@ -601,10 +655,12 @@ class VaultRepositoryTest {
@Test @Test
fun `sync should be able to be called after unlockVaultAndSyncForCurrentUser is canceled`() = fun `sync should be able to be called after unlockVaultAndSyncForCurrentUser is canceled`() =
runTest { runTest {
val userId = "mockId-1"
val mockSyncResponse = createMockSyncResponse(number = 1) val mockSyncResponse = createMockSyncResponse(number = 1)
coEvery { syncService.sync() } returns mockSyncResponse.asSuccess() coEvery { syncService.sync() } returns mockSyncResponse.asSuccess()
coEvery { coEvery {
vaultSdkSource.initializeOrganizationCrypto( vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest( request = InitOrgCryptoRequest(
organizationKeys = createMockOrganizationKeys(1), organizationKeys = createMockOrganizationKeys(1),
), ),
@ -617,7 +673,10 @@ class VaultRepositoryTest {
) )
} just runs } just runs
coEvery { coEvery {
vaultSdkSource.decryptSendList(listOf(createMockSdkSend(number = 1))) vaultSdkSource.decryptSendList(
userId = userId,
sendList = listOf(createMockSdkSend(number = 1)),
)
} returns listOf(createMockSendView(number = 1)).asSuccess() } returns listOf(createMockSendView(number = 1)).asSuccess()
fakeAuthDiskSource.storePrivateKey( fakeAuthDiskSource.storePrivateKey(
userId = "mockId-1", userId = "mockId-1",
@ -630,6 +689,7 @@ class VaultRepositoryTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
coEvery { coEvery {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()), kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()),
email = "email", email = "email",
@ -668,8 +728,10 @@ class VaultRepositoryTest {
userKey = "mockKey-1", userKey = "mockKey-1",
) )
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
coEvery { coEvery {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()), kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()),
email = "email", email = "email",
@ -711,8 +773,10 @@ class VaultRepositoryTest {
userKey = "mockKey-1", userKey = "mockKey-1",
) )
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
coEvery { coEvery {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()), kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()),
email = "email", email = "email",
@ -767,8 +831,10 @@ class VaultRepositoryTest {
organizationKeys = createMockOrganizationKeys(number = 1), organizationKeys = createMockOrganizationKeys(number = 1),
) )
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
coEvery { coEvery {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()), kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()),
email = "email", email = "email",
@ -782,6 +848,7 @@ class VaultRepositoryTest {
} returns InitializeCryptoResult.Success.asSuccess() } returns InitializeCryptoResult.Success.asSuccess()
coEvery { coEvery {
vaultSdkSource.initializeOrganizationCrypto( vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest( request = InitOrgCryptoRequest(
organizationKeys = createMockOrganizationKeys(1), organizationKeys = createMockOrganizationKeys(1),
), ),
@ -825,8 +892,10 @@ class VaultRepositoryTest {
userKey = "mockKey-1", userKey = "mockKey-1",
) )
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
coEvery { coEvery {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()), kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()),
email = "email", email = "email",
@ -876,8 +945,10 @@ class VaultRepositoryTest {
organizationKeys = createMockOrganizationKeys(number = 1), organizationKeys = createMockOrganizationKeys(number = 1),
) )
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
coEvery { coEvery {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()), kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()),
email = "email", email = "email",
@ -891,6 +962,7 @@ class VaultRepositoryTest {
} returns InitializeCryptoResult.Success.asSuccess() } returns InitializeCryptoResult.Success.asSuccess()
coEvery { coEvery {
vaultSdkSource.initializeOrganizationCrypto( vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest( request = InitOrgCryptoRequest(
organizationKeys = createMockOrganizationKeys(1), organizationKeys = createMockOrganizationKeys(1),
), ),
@ -963,6 +1035,7 @@ class VaultRepositoryTest {
privateKey = "mockPrivateKey-1", privateKey = "mockPrivateKey-1",
) )
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
assertEquals( assertEquals(
VaultUnlockResult.InvalidStateError, VaultUnlockResult.InvalidStateError,
result, result,
@ -995,6 +1068,7 @@ class VaultRepositoryTest {
privateKey = null, privateKey = null,
) )
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
assertEquals( assertEquals(
VaultUnlockResult.InvalidStateError, VaultUnlockResult.InvalidStateError,
result, result,
@ -1018,6 +1092,7 @@ class VaultRepositoryTest {
val organizationKeys = mapOf("orgId1" to "orgKey1") val organizationKeys = mapOf("orgId1" to "orgKey1")
coEvery { coEvery {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = kdf, kdfParams = kdf,
email = email, email = email,
@ -1031,6 +1106,7 @@ class VaultRepositoryTest {
} returns InitializeCryptoResult.Success.asSuccess() } returns InitializeCryptoResult.Success.asSuccess()
coEvery { coEvery {
vaultSdkSource.initializeOrganizationCrypto( vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest(organizationKeys = organizationKeys), request = InitOrgCryptoRequest(organizationKeys = organizationKeys),
) )
} returns InitializeCryptoResult.Success.asSuccess() } returns InitializeCryptoResult.Success.asSuccess()
@ -1060,6 +1136,7 @@ class VaultRepositoryTest {
) )
coVerify(exactly = 1) { coVerify(exactly = 1) {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = kdf, kdfParams = kdf,
email = email, email = email,
@ -1073,6 +1150,7 @@ class VaultRepositoryTest {
} }
coVerify(exactly = 1) { coVerify(exactly = 1) {
vaultSdkSource.initializeOrganizationCrypto( vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest(organizationKeys = organizationKeys), request = InitOrgCryptoRequest(organizationKeys = organizationKeys),
) )
} }
@ -1091,6 +1169,7 @@ class VaultRepositoryTest {
val organizationKeys = mapOf("orgId1" to "orgKey1") val organizationKeys = mapOf("orgId1" to "orgKey1")
coEvery { coEvery {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = kdf, kdfParams = kdf,
email = email, email = email,
@ -1129,6 +1208,7 @@ class VaultRepositoryTest {
) )
coVerify(exactly = 1) { coVerify(exactly = 1) {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = kdf, kdfParams = kdf,
email = email, email = email,
@ -1155,6 +1235,7 @@ class VaultRepositoryTest {
val organizationKeys = mapOf("orgId1" to "orgKey1") val organizationKeys = mapOf("orgId1" to "orgKey1")
coEvery { coEvery {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = kdf, kdfParams = kdf,
email = email, email = email,
@ -1168,6 +1249,7 @@ class VaultRepositoryTest {
} returns InitializeCryptoResult.Success.asSuccess() } returns InitializeCryptoResult.Success.asSuccess()
coEvery { coEvery {
vaultSdkSource.initializeOrganizationCrypto( vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest(organizationKeys = organizationKeys), request = InitOrgCryptoRequest(organizationKeys = organizationKeys),
) )
} returns InitializeCryptoResult.AuthenticationError.asSuccess() } returns InitializeCryptoResult.AuthenticationError.asSuccess()
@ -1198,6 +1280,7 @@ class VaultRepositoryTest {
) )
coVerify(exactly = 1) { coVerify(exactly = 1) {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = kdf, kdfParams = kdf,
email = email, email = email,
@ -1211,6 +1294,7 @@ class VaultRepositoryTest {
} }
coVerify(exactly = 1) { coVerify(exactly = 1) {
vaultSdkSource.initializeOrganizationCrypto( vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest(organizationKeys = organizationKeys), request = InitOrgCryptoRequest(organizationKeys = organizationKeys),
) )
} }
@ -1228,6 +1312,7 @@ class VaultRepositoryTest {
val organizationKeys = mapOf("orgId1" to "orgKey1") val organizationKeys = mapOf("orgId1" to "orgKey1")
coEvery { coEvery {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = kdf, kdfParams = kdf,
email = email, email = email,
@ -1265,6 +1350,7 @@ class VaultRepositoryTest {
) )
coVerify(exactly = 1) { coVerify(exactly = 1) {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = kdf, kdfParams = kdf,
email = email, email = email,
@ -1290,6 +1376,7 @@ class VaultRepositoryTest {
val organizationKeys = mapOf("orgId1" to "orgKey1") val organizationKeys = mapOf("orgId1" to "orgKey1")
coEvery { coEvery {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = kdf, kdfParams = kdf,
email = email, email = email,
@ -1303,6 +1390,7 @@ class VaultRepositoryTest {
} returns InitializeCryptoResult.Success.asSuccess() } returns InitializeCryptoResult.Success.asSuccess()
coEvery { coEvery {
vaultSdkSource.initializeOrganizationCrypto( vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest(organizationKeys = organizationKeys), request = InitOrgCryptoRequest(organizationKeys = organizationKeys),
) )
} returns Throwable("Fail").asFailure() } returns Throwable("Fail").asFailure()
@ -1332,6 +1420,7 @@ class VaultRepositoryTest {
) )
coVerify(exactly = 1) { coVerify(exactly = 1) {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = kdf, kdfParams = kdf,
email = email, email = email,
@ -1345,6 +1434,7 @@ class VaultRepositoryTest {
} }
coVerify(exactly = 1) { coVerify(exactly = 1) {
vaultSdkSource.initializeOrganizationCrypto( vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest(organizationKeys = organizationKeys), request = InitOrgCryptoRequest(organizationKeys = organizationKeys),
) )
} }
@ -1361,6 +1451,7 @@ class VaultRepositoryTest {
val organizationKeys = null val organizationKeys = null
coEvery { coEvery {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = kdf, kdfParams = kdf,
email = email, email = email,
@ -1392,6 +1483,7 @@ class VaultRepositoryTest {
coVerify(exactly = 0) { syncService.sync() } coVerify(exactly = 0) { syncService.sync() }
coVerify(exactly = 1) { coVerify(exactly = 1) {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = kdf, kdfParams = kdf,
email = email, email = email,
@ -1408,17 +1500,30 @@ class VaultRepositoryTest {
@Test @Test
fun `clearUnlockedData should update the vaultDataStateFlow to Loading`() = runTest { fun `clearUnlockedData should update the vaultDataStateFlow to Loading`() = runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
coEvery { coEvery {
vaultSdkSource.decryptCipherList(listOf(createMockSdkCipher(1))) vaultSdkSource.decryptCipherList(
userId = userId,
cipherList = listOf(createMockSdkCipher(1)),
)
} returns listOf(createMockCipherView(number = 1)).asSuccess() } returns listOf(createMockCipherView(number = 1)).asSuccess()
coEvery { coEvery {
vaultSdkSource.decryptFolderList(listOf(createMockSdkFolder(1))) vaultSdkSource.decryptFolderList(
userId = userId,
folderList = listOf(createMockSdkFolder(1)),
)
} returns listOf(createMockFolderView(number = 1)).asSuccess() } returns listOf(createMockFolderView(number = 1)).asSuccess()
coEvery { coEvery {
vaultSdkSource.decryptCollectionList(listOf(createMockSdkCollection(1))) vaultSdkSource.decryptCollectionList(
userId = userId,
collectionList = listOf(createMockSdkCollection(1)),
)
} returns listOf(createMockCollectionView(number = 1)).asSuccess() } returns listOf(createMockCollectionView(number = 1)).asSuccess()
coEvery { coEvery {
vaultSdkSource.decryptSendList(listOf(createMockSdkSend(number = 1))) vaultSdkSource.decryptSendList(
userId = userId,
sendList = listOf(createMockSdkSend(number = 1)),
)
} returns listOf(createMockSendView(number = 1)).asSuccess() } returns listOf(createMockSendView(number = 1)).asSuccess()
val ciphersFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Cipher>>() val ciphersFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Cipher>>()
val collectionsFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Collection>>() val collectionsFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Collection>>()
@ -1456,8 +1561,12 @@ class VaultRepositoryTest {
@Test @Test
fun `clearUnlockedData should update the sendDataStateFlow to Loading`() = runTest { fun `clearUnlockedData should update the sendDataStateFlow to Loading`() = runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
coEvery { coEvery {
vaultSdkSource.decryptSendList(listOf(createMockSdkSend(number = 1))) vaultSdkSource.decryptSendList(
userId = userId,
sendList = listOf(createMockSdkSend(number = 1)),
)
} returns listOf(createMockSendView(number = 1)).asSuccess() } returns listOf(createMockSendView(number = 1)).asSuccess()
val sendsFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Send>>() val sendsFlow = bufferedMutableSharedFlow<List<SyncResponseJson.Send>>()
setupVaultDiskSourceFlows(sendsFlow = sendsFlow) setupVaultDiskSourceFlows(sendsFlow = sendsFlow)
@ -1486,6 +1595,7 @@ class VaultRepositoryTest {
val folderIdString = "mockId-$folderId" val folderIdString = "mockId-$folderId"
val throwable = Throwable("Fail") val throwable = Throwable("Fail")
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
coEvery { syncService.sync() } returns throwable.asFailure() coEvery { syncService.sync() } returns throwable.asFailure()
setupVaultDiskSourceFlows() setupVaultDiskSourceFlows()
@ -1506,6 +1616,7 @@ class VaultRepositoryTest {
val itemId = 1234 val itemId = 1234
val itemIdString = "mockId-$itemId" val itemIdString = "mockId-$itemId"
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
coEvery { syncService.sync() } returns UnknownHostException().asFailure() coEvery { syncService.sync() } returns UnknownHostException().asFailure()
setupVaultDiskSourceFlows() setupVaultDiskSourceFlows()
@ -1526,6 +1637,7 @@ class VaultRepositoryTest {
val folderId = 1234 val folderId = 1234
val folderIdString = "mockId-$folderId" val folderIdString = "mockId-$folderId"
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
coEvery { syncService.sync() } returns UnknownHostException().asFailure() coEvery { syncService.sync() } returns UnknownHostException().asFailure()
setupVaultDiskSourceFlows() setupVaultDiskSourceFlows()
@ -1546,6 +1658,7 @@ class VaultRepositoryTest {
val folderIdString = "mockId-$folderId" val folderIdString = "mockId-$folderId"
val throwable = Throwable("Fail") val throwable = Throwable("Fail")
fakeAuthDiskSource.userState = MOCK_USER_STATE fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
coEvery { syncService.sync() } returns throwable.asFailure() coEvery { syncService.sync() } returns throwable.asFailure()
setupVaultDiskSourceFlows() setupVaultDiskSourceFlows()
@ -1563,9 +1676,14 @@ class VaultRepositoryTest {
@Test @Test
fun `createCipher with encryptCipher failure should return CreateCipherResult failure`() = fun `createCipher with encryptCipher failure should return CreateCipherResult failure`() =
runTest { runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val mockCipherView = createMockCipherView(number = 1) val mockCipherView = createMockCipherView(number = 1)
coEvery { coEvery {
vaultSdkSource.encryptCipher(cipherView = mockCipherView) vaultSdkSource.encryptCipher(
userId = userId,
cipherView = mockCipherView,
)
} returns IllegalStateException().asFailure() } returns IllegalStateException().asFailure()
val result = vaultRepository.createCipher(cipherView = mockCipherView) val result = vaultRepository.createCipher(cipherView = mockCipherView)
@ -1580,9 +1698,14 @@ class VaultRepositoryTest {
@Suppress("MaxLineLength") @Suppress("MaxLineLength")
fun `createCipher with ciphersService createCipher failure should return CreateCipherResult failure`() = fun `createCipher with ciphersService createCipher failure should return CreateCipherResult failure`() =
runTest { runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val mockCipherView = createMockCipherView(number = 1) val mockCipherView = createMockCipherView(number = 1)
coEvery { coEvery {
vaultSdkSource.encryptCipher(cipherView = mockCipherView) vaultSdkSource.encryptCipher(
userId = userId,
cipherView = mockCipherView,
)
} returns createMockSdkCipher(number = 1).asSuccess() } returns createMockSdkCipher(number = 1).asSuccess()
coEvery { coEvery {
ciphersService.createCipher( ciphersService.createCipher(
@ -1602,9 +1725,14 @@ class VaultRepositoryTest {
@Suppress("MaxLineLength") @Suppress("MaxLineLength")
fun `createCipher with ciphersService createCipher success should return CreateCipherResult success`() = fun `createCipher with ciphersService createCipher success should return CreateCipherResult success`() =
runTest { runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val mockCipherView = createMockCipherView(number = 1) val mockCipherView = createMockCipherView(number = 1)
coEvery { coEvery {
vaultSdkSource.encryptCipher(cipherView = mockCipherView) vaultSdkSource.encryptCipher(
userId = userId,
cipherView = mockCipherView,
)
} returns createMockSdkCipher(number = 1).asSuccess() } returns createMockSdkCipher(number = 1).asSuccess()
coEvery { coEvery {
ciphersService.createCipher( ciphersService.createCipher(
@ -1614,15 +1742,25 @@ class VaultRepositoryTest {
coEvery { coEvery {
syncService.sync() syncService.sync()
} returns Result.success(createMockSyncResponse(1)) } returns Result.success(createMockSyncResponse(1))
coEvery {
vaultDiskSource.replaceVaultData(
userId = userId,
vault = createMockSyncResponse(1),
)
} just runs
coEvery { coEvery {
vaultSdkSource.initializeOrganizationCrypto( vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest( request = InitOrgCryptoRequest(
organizationKeys = createMockOrganizationKeys(1), organizationKeys = createMockOrganizationKeys(1),
), ),
) )
} returns InitializeCryptoResult.Success.asSuccess() } returns InitializeCryptoResult.Success.asSuccess()
coEvery { coEvery {
vaultSdkSource.decryptSendList(listOf(createMockSdkSend(1))) vaultSdkSource.decryptSendList(
userId = userId,
sendList = listOf(createMockSdkSend(1)),
)
} returns listOf(createMockSendView(1)).asSuccess() } returns listOf(createMockSendView(1)).asSuccess()
val result = vaultRepository.createCipher(cipherView = mockCipherView) val result = vaultRepository.createCipher(cipherView = mockCipherView)
@ -1636,10 +1774,15 @@ class VaultRepositoryTest {
@Test @Test
fun `updateCipher with encryptCipher failure should return UpdateCipherResult failure`() = fun `updateCipher with encryptCipher failure should return UpdateCipherResult failure`() =
runTest { runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val cipherId = "cipherId1234" val cipherId = "cipherId1234"
val mockCipherView = createMockCipherView(number = 1) val mockCipherView = createMockCipherView(number = 1)
coEvery { coEvery {
vaultSdkSource.encryptCipher(cipherView = mockCipherView) vaultSdkSource.encryptCipher(
userId = userId,
cipherView = mockCipherView,
)
} returns IllegalStateException().asFailure() } returns IllegalStateException().asFailure()
val result = vaultRepository.updateCipher( val result = vaultRepository.updateCipher(
@ -1654,10 +1797,15 @@ class VaultRepositoryTest {
@Suppress("MaxLineLength") @Suppress("MaxLineLength")
fun `updateCipher with ciphersService updateCipher failure should return UpdateCipherResult Error with a null message`() = fun `updateCipher with ciphersService updateCipher failure should return UpdateCipherResult Error with a null message`() =
runTest { runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val cipherId = "cipherId1234" val cipherId = "cipherId1234"
val mockCipherView = createMockCipherView(number = 1) val mockCipherView = createMockCipherView(number = 1)
coEvery { coEvery {
vaultSdkSource.encryptCipher(cipherView = mockCipherView) vaultSdkSource.encryptCipher(
userId = userId,
cipherView = mockCipherView,
)
} returns createMockSdkCipher(number = 1).asSuccess() } returns createMockSdkCipher(number = 1).asSuccess()
coEvery { coEvery {
ciphersService.updateCipher( ciphersService.updateCipher(
@ -1678,10 +1826,15 @@ class VaultRepositoryTest {
@Suppress("MaxLineLength") @Suppress("MaxLineLength")
fun `updateCipher with ciphersService updateCipher Invalid response should return UpdateCipherResult Error with a non-null message`() = fun `updateCipher with ciphersService updateCipher Invalid response should return UpdateCipherResult Error with a non-null message`() =
runTest { runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val cipherId = "cipherId1234" val cipherId = "cipherId1234"
val mockCipherView = createMockCipherView(number = 1) val mockCipherView = createMockCipherView(number = 1)
coEvery { coEvery {
vaultSdkSource.encryptCipher(cipherView = mockCipherView) vaultSdkSource.encryptCipher(
userId = userId,
cipherView = mockCipherView,
)
} returns createMockSdkCipher(number = 1).asSuccess() } returns createMockSdkCipher(number = 1).asSuccess()
coEvery { coEvery {
ciphersService.updateCipher( ciphersService.updateCipher(
@ -1712,10 +1865,15 @@ class VaultRepositoryTest {
@Suppress("MaxLineLength") @Suppress("MaxLineLength")
fun `updateCipher with ciphersService updateCipher Success response should return UpdateCipherResult success`() = fun `updateCipher with ciphersService updateCipher Success response should return UpdateCipherResult success`() =
runTest { runTest {
fakeAuthDiskSource.userState = MOCK_USER_STATE
val userId = "mockId-1"
val cipherId = "cipherId1234" val cipherId = "cipherId1234"
val mockCipherView = createMockCipherView(number = 1) val mockCipherView = createMockCipherView(number = 1)
coEvery { coEvery {
vaultSdkSource.encryptCipher(cipherView = mockCipherView) vaultSdkSource.encryptCipher(
userId = userId,
cipherView = mockCipherView,
)
} returns createMockSdkCipher(number = 1).asSuccess() } returns createMockSdkCipher(number = 1).asSuccess()
coEvery { coEvery {
ciphersService.updateCipher( ciphersService.updateCipher(
@ -1728,15 +1886,25 @@ class VaultRepositoryTest {
coEvery { coEvery {
syncService.sync() syncService.sync()
} returns Result.success(createMockSyncResponse(1)) } returns Result.success(createMockSyncResponse(1))
coEvery {
vaultDiskSource.replaceVaultData(
userId = userId,
vault = createMockSyncResponse(1),
)
} just runs
coEvery { coEvery {
vaultSdkSource.initializeOrganizationCrypto( vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest( request = InitOrgCryptoRequest(
organizationKeys = createMockOrganizationKeys(1), organizationKeys = createMockOrganizationKeys(1),
), ),
) )
} returns InitializeCryptoResult.Success.asSuccess() } returns InitializeCryptoResult.Success.asSuccess()
coEvery { coEvery {
vaultSdkSource.decryptSendList(listOf(createMockSdkSend(1))) vaultSdkSource.decryptSendList(
userId = userId,
sendList = listOf(createMockSdkSend(1)),
)
} returns listOf(createMockSendView(1)).asSuccess() } returns listOf(createMockSendView(1)).asSuccess()
val result = vaultRepository.updateCipher( val result = vaultRepository.updateCipher(
@ -1779,6 +1947,7 @@ class VaultRepositoryTest {
val organizationKeys = null val organizationKeys = null
coEvery { coEvery {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = kdf, kdfParams = kdf,
email = email, email = email,
@ -1804,6 +1973,7 @@ class VaultRepositoryTest {
assertEquals(VaultUnlockResult.Success, result) assertEquals(VaultUnlockResult.Success, result)
coVerify(exactly = 1) { coVerify(exactly = 1) {
vaultSdkSource.initializeCrypto( vaultSdkSource.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest( request = InitUserCryptoRequest(
kdfParams = kdf, kdfParams = kdf,
email = email, email = email,