diff --git a/app/src/main/java/com/x8bit/bitwarden/data/auth/datasource/disk/AuthDiskSource.kt b/app/src/main/java/com/x8bit/bitwarden/data/auth/datasource/disk/AuthDiskSource.kt index af7064375..99b31d841 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/auth/datasource/disk/AuthDiskSource.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/auth/datasource/disk/AuthDiskSource.kt @@ -1,6 +1,7 @@ package com.x8bit.bitwarden.data.auth.datasource.disk import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson +import com.x8bit.bitwarden.data.vault.datasource.network.model.SyncResponseJson import kotlinx.coroutines.flow.Flow /** @@ -63,4 +64,22 @@ interface AuthDiskSource { userId: String, organizationKeys: Map<String, String>?, ) + + /** + * Gets the organization data for the given [userId]. + */ + fun getOrganizations(userId: String): List<SyncResponseJson.Profile.Organization>? + + /** + * Emits updates that track [getOrganizations]. This will replay the last known value, if any. + */ + fun getOrganizationsFlow(userId: String): Flow<List<SyncResponseJson.Profile.Organization>?> + + /** + * Stores the organization data for the given [userId]. + */ + fun storeOrganizations( + userId: String, + organizations: List<SyncResponseJson.Profile.Organization>?, + ) } diff --git a/app/src/main/java/com/x8bit/bitwarden/data/auth/datasource/disk/AuthDiskSourceImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/auth/datasource/disk/AuthDiskSourceImpl.kt index 865afd7d3..767f8a86c 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/auth/datasource/disk/AuthDiskSourceImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/auth/datasource/disk/AuthDiskSourceImpl.kt @@ -4,6 +4,7 @@ import android.content.SharedPreferences import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson import com.x8bit.bitwarden.data.platform.datasource.disk.BaseDiskSource import com.x8bit.bitwarden.data.platform.datasource.disk.BaseDiskSource.Companion.BASE_KEY +import com.x8bit.bitwarden.data.vault.datasource.network.model.SyncResponseJson import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.onSubscription @@ -16,6 +17,7 @@ private const val REMEMBERED_EMAIL_ADDRESS_KEY = "$BASE_KEY:rememberedEmail" private const val STATE_KEY = "$BASE_KEY:state" private const val MASTER_KEY_ENCRYPTION_USER_KEY = "$BASE_KEY:masterKeyEncryptedUserKey" private const val MASTER_KEY_ENCRYPTION_PRIVATE_KEY = "$BASE_KEY:encPrivateKey" +private const val ORGANIZATIONS_KEY = "$BASE_KEY:organizations" private const val ORGANIZATION_KEYS_KEY = "$BASE_KEY:encOrgKeys" /** @@ -26,6 +28,12 @@ class AuthDiskSourceImpl( private val json: Json, ) : BaseDiskSource(sharedPreferences = sharedPreferences), AuthDiskSource { + private val mutableOrganizationsFlow = + MutableSharedFlow<List<SyncResponseJson.Profile.Organization>?>( + replay = 1, + extraBufferCapacity = Int.MAX_VALUE, + ) + override val uniqueAppId: String get() = getString(key = UNIQUE_APP_ID_KEY) ?: generateAndStoreUniqueAppId() @@ -91,6 +99,29 @@ class AuthDiskSourceImpl( ) } + override fun getOrganizations( + userId: String, + ): List<SyncResponseJson.Profile.Organization>? = + getString(key = "${ORGANIZATIONS_KEY}_$userId") + ?.let { json.decodeFromString(it) } + + override fun getOrganizationsFlow( + userId: String, + ): Flow<List<SyncResponseJson.Profile.Organization>?> = + mutableOrganizationsFlow + .onSubscription { emit(getOrganizations(userId = userId)) } + + override fun storeOrganizations( + userId: String, + organizations: List<SyncResponseJson.Profile.Organization>?, + ) { + putString( + key = "${ORGANIZATIONS_KEY}_$userId", + value = organizations?.let { json.encodeToString(it) }, + ) + mutableOrganizationsFlow.tryEmit(organizations) + } + private fun generateAndStoreUniqueAppId(): String = UUID .randomUUID() diff --git a/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryImpl.kt index 4a31a4e73..c6d126a5d 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryImpl.kt @@ -245,6 +245,7 @@ class AuthRepositoryImpl constructor( storeUserKey(userId = userId, userKey = null) storePrivateKey(userId = userId, privateKey = null) storeOrganizationKeys(userId = userId, organizationKeys = null) + storeOrganizations(userId = userId, organizations = null) } // Check if there is a new active user diff --git a/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt index 52f7ff42f..a7e8bdefa 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt @@ -191,7 +191,7 @@ class VaultRepositoryImpl( ) unlockVaultForOrganizationsIfNecessary(syncResponse = syncResponse) - storeKeys(syncResponse = syncResponse) + storeProfileData(syncResponse = syncResponse) vaultDiskSource.replaceVaultData(userId = userId, vault = syncResponse) decryptSendsAndUpdateSendDataState(sendList = syncResponse.sends) }, @@ -403,7 +403,7 @@ class VaultRepositoryImpl( } } - private fun storeKeys( + private fun storeProfileData( syncResponse: SyncResponseJson, ) { val profile = syncResponse.profile @@ -426,6 +426,10 @@ class VaultRepositoryImpl( .filter { it.key != null } .associate { it.id to requireNotNull(it.key) }, ) + storeOrganizations( + userId = profile.id, + organizations = syncResponse.profile.organizations, + ) } } diff --git a/app/src/test/java/com/x8bit/bitwarden/data/auth/datasource/disk/AuthDiskSourceTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/auth/datasource/disk/AuthDiskSourceTest.kt index f3a4ab0ac..cc2161351 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/auth/datasource/disk/AuthDiskSourceTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/auth/datasource/disk/AuthDiskSourceTest.kt @@ -10,9 +10,11 @@ import com.x8bit.bitwarden.data.auth.datasource.network.model.KeyConnectorUserDe import com.x8bit.bitwarden.data.auth.datasource.network.model.TrustedDeviceUserDecryptionOptionsJson import com.x8bit.bitwarden.data.auth.datasource.network.model.UserDecryptionOptionsJson import com.x8bit.bitwarden.data.platform.base.FakeSharedPreferences +import com.x8bit.bitwarden.data.platform.datasource.network.di.PlatformNetworkModule +import com.x8bit.bitwarden.data.vault.datasource.network.model.createMockOrganization import kotlinx.coroutines.test.runTest -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.json.Json +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.encodeToJsonElement import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertNull import org.junit.jupiter.api.Test @@ -20,11 +22,7 @@ import org.junit.jupiter.api.Test class AuthDiskSourceTest { private val fakeSharedPreferences = FakeSharedPreferences() - @OptIn(ExperimentalSerializationApi::class) - private val json = Json { - ignoreUnknownKeys = true - explicitNulls = false - } + private val json = PlatformNetworkModule.providesJson() private val authDiskSource = AuthDiskSourceImpl( sharedPreferences = fakeSharedPreferences, @@ -250,6 +248,71 @@ class AuthDiskSourceTest { json.parseToJsonElement(requireNotNull(actual)), ) } + + @Test + fun `getOrganizations should pull from SharedPreferences`() { + val organizationsBaseKey = "bwPreferencesStorage:organizations" + val mockUserId = "mockUserId" + val mockOrganizations = listOf( + createMockOrganization(0), + createMockOrganization(1), + ) + fakeSharedPreferences + .edit() + .putString( + "${organizationsBaseKey}_$mockUserId", + json.encodeToString(mockOrganizations), + ) + .apply() + val actual = authDiskSource.getOrganizations(userId = mockUserId) + assertEquals( + mockOrganizations, + actual, + ) + } + + @Test + fun `getOrganizationsFlow should react to changes in getOrganizations`() = runTest { + val mockUserId = "mockUserId" + val mockOrganizations = listOf( + createMockOrganization(0), + createMockOrganization(1), + ) + authDiskSource.getOrganizationsFlow(userId = mockUserId).test { + // The initial values of the Flow and the property are in sync + assertNull(authDiskSource.getOrganizations(userId = mockUserId)) + assertNull(awaitItem()) + + // Updating the repository updates shared preferences + authDiskSource.storeOrganizations( + userId = mockUserId, + organizations = mockOrganizations, + ) + assertEquals(mockOrganizations, awaitItem()) + } + } + + @Test + fun `storeOrganizations should update SharedPreferences`() { + val organizationsBaseKey = "bwPreferencesStorage:organizations" + val mockUserId = "mockUserId" + val mockOrganizations = listOf( + createMockOrganization(0), + createMockOrganization(1), + ) + authDiskSource.storeOrganizations( + userId = mockUserId, + organizations = mockOrganizations, + ) + val actual = fakeSharedPreferences.getString( + "${organizationsBaseKey}_$mockUserId", + null, + ) + assertEquals( + json.encodeToJsonElement(mockOrganizations), + json.parseToJsonElement(requireNotNull(actual)), + ) + } } private const val USER_STATE_JSON = """ diff --git a/app/src/test/java/com/x8bit/bitwarden/data/auth/datasource/disk/util/FakeAuthDiskSource.kt b/app/src/test/java/com/x8bit/bitwarden/data/auth/datasource/disk/util/FakeAuthDiskSource.kt index 313980112..00e0ccfe8 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/auth/datasource/disk/util/FakeAuthDiskSource.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/auth/datasource/disk/util/FakeAuthDiskSource.kt @@ -2,6 +2,7 @@ package com.x8bit.bitwarden.data.auth.datasource.disk.util import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson +import com.x8bit.bitwarden.data.vault.datasource.network.model.SyncResponseJson import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.onSubscription @@ -13,6 +14,8 @@ class FakeAuthDiskSource : AuthDiskSource { override var rememberedEmailAddress: String? = null + private val mutableOrganizationsFlowMap = + mutableMapOf<String, MutableSharedFlow<List<SyncResponseJson.Profile.Organization>?>>() private val mutableUserStateFlow = MutableSharedFlow<UserStateJson?>( replay = 1, @@ -21,6 +24,8 @@ class FakeAuthDiskSource : AuthDiskSource { private val storedUserKeys = mutableMapOf<String, String?>() private val storedPrivateKeys = mutableMapOf<String, String?>() + private val storedOrganizations = + mutableMapOf<String, List<SyncResponseJson.Profile.Organization>?>() private val storedOrganizationKeys = mutableMapOf<String, Map<String, String>?>() override var userState: UserStateJson? = null @@ -55,6 +60,23 @@ class FakeAuthDiskSource : AuthDiskSource { storedOrganizationKeys[userId] = organizationKeys } + override fun getOrganizations( + userId: String, + ): List<SyncResponseJson.Profile.Organization>? = storedOrganizations[userId] + + override fun getOrganizationsFlow( + userId: String, + ): Flow<List<SyncResponseJson.Profile.Organization>?> = + getMutableOrganizationsFlow(userId).onSubscription { emit(getOrganizations(userId)) } + + override fun storeOrganizations( + userId: String, + organizations: List<SyncResponseJson.Profile.Organization>?, + ) { + storedOrganizations[userId] = organizations + getMutableOrganizationsFlow(userId = userId).tryEmit(organizations) + } + /** * Assert that the given [userState] matches the currently tracked value. */ @@ -82,4 +104,28 @@ class FakeAuthDiskSource : AuthDiskSource { fun assertOrganizationKeys(userId: String, organizationKeys: Map<String, String>?) { assertEquals(organizationKeys, storedOrganizationKeys[userId]) } + + /** + * Assert that the [organizations] were stored successfully using the [userId]. + */ + fun assertOrganizations( + userId: String, + organizations: List<SyncResponseJson.Profile.Organization>?, + ) { + assertEquals(organizations, storedOrganizations[userId]) + } + + //region Private helper functions + + private fun getMutableOrganizationsFlow( + userId: String, + ): MutableSharedFlow<List<SyncResponseJson.Profile.Organization>?> = + mutableOrganizationsFlowMap.getOrPut(userId) { + MutableSharedFlow( + replay = 1, + extraBufferCapacity = Int.MAX_VALUE, + ) + } + + //endregion Private helper functions } diff --git a/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt index fd426d35e..6d41aacd0 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt @@ -44,6 +44,7 @@ import com.x8bit.bitwarden.data.platform.repository.model.Environment import com.x8bit.bitwarden.data.platform.repository.util.FakeEnvironmentRepository import com.x8bit.bitwarden.data.platform.util.asFailure import com.x8bit.bitwarden.data.platform.util.asSuccess +import com.x8bit.bitwarden.data.vault.datasource.network.model.createMockOrganization import com.x8bit.bitwarden.data.vault.repository.VaultRepository import com.x8bit.bitwarden.data.vault.repository.model.VaultState import com.x8bit.bitwarden.data.vault.repository.model.VaultUnlockResult @@ -927,8 +928,9 @@ class AuthRepositoryTest { } } + @Suppress("MaxLineLength") @Test - fun `logout for single account should clear the access token and stored keys`() = runTest { + fun `logout for single account should clear the access token and profile data`() = runTest { // First login: val successResponse = GET_TOKEN_RESPONSE_SUCCESS coEvery { @@ -973,6 +975,10 @@ class AuthRepositoryTest { userId = USER_ID_1, organizationKeys = ORGANIZATION_KEYS, ) + storeOrganizations( + userId = USER_ID_1, + organizations = ORGANIZATIONS, + ) } repository.login(email = EMAIL, password = PASSWORD, captchaToken = null) @@ -1000,6 +1006,10 @@ class AuthRepositoryTest { userId = USER_ID_1, organizationKeys = null, ) + fakeAuthDiskSource.assertOrganizations( + userId = USER_ID_1, + organizations = null, + ) verify { vaultRepository.deleteVaultData(userId = USER_ID_1) } verify { vaultRepository.clearUnlockedData() } verify { vaultRepository.lockVaultIfNecessary(userId = USER_ID_1) } @@ -1356,6 +1366,7 @@ class AuthRepositoryTest { private const val USER_ID_2 = "b9d32ec0-6497-4582-9798-b350f53bfa02" private const val USER_ID_3 = "3816ef34-0747-4133-9b7a-ba35d3768a68" private val ORGANIZATION_KEYS = mapOf("organizationId1" to "organizationKey1") + private val ORGANIZATIONS = listOf(createMockOrganization(number = 0)) private val PRE_LOGIN_SUCCESS = PreLoginResponseJson( kdfParams = PreLoginResponseJson.KdfParams.Pbkdf2(iterations = 1u), ) diff --git a/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt index eaa23540d..ed8a7bfb4 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt @@ -25,6 +25,7 @@ import com.x8bit.bitwarden.data.vault.datasource.network.model.createMockCipher import com.x8bit.bitwarden.data.vault.datasource.network.model.createMockCipherJsonRequest import com.x8bit.bitwarden.data.vault.datasource.network.model.createMockCollection import com.x8bit.bitwarden.data.vault.datasource.network.model.createMockFolder +import com.x8bit.bitwarden.data.vault.datasource.network.model.createMockOrganization import com.x8bit.bitwarden.data.vault.datasource.network.model.createMockOrganizationKeys import com.x8bit.bitwarden.data.vault.datasource.network.model.createMockSyncResponse import com.x8bit.bitwarden.data.vault.datasource.network.service.CiphersService @@ -310,6 +311,10 @@ class VaultRepositoryTest { userId = "mockId-1", organizationKeys = mapOf("mockId-1" to "mockKey-1"), ) + fakeAuthDiskSource.assertOrganizations( + userId = "mockId-1", + organizations = listOf(createMockOrganization(number = 1)), + ) assertEquals( DataState.Loaded( data = SendData(