BIT-2070: Enable individual Cipher Encryption for SDK (#1148)

This commit is contained in:
Ramsey Smith 2024-03-18 12:41:51 -06:00 committed by Álison Fernandes
parent 3c715d39d6
commit 0d2467d8d2
13 changed files with 144 additions and 43 deletions

View file

@ -12,6 +12,10 @@ import com.bitwarden.sdk.ClientPlatform
import com.x8bit.bitwarden.data.auth.datasource.sdk.model.PasswordStrength import com.x8bit.bitwarden.data.auth.datasource.sdk.model.PasswordStrength
import com.x8bit.bitwarden.data.auth.datasource.sdk.util.toPasswordStrengthOrNull import com.x8bit.bitwarden.data.auth.datasource.sdk.util.toPasswordStrengthOrNull
import com.x8bit.bitwarden.data.auth.datasource.sdk.util.toUByte import com.x8bit.bitwarden.data.auth.datasource.sdk.util.toUByte
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
import com.x8bit.bitwarden.data.vault.datasource.sdk.BitwardenFeatureFlagManager
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch
/** /**
* Primary implementation of [AuthSdkSource] that serves as a convenience wrapper around a * Primary implementation of [AuthSdkSource] that serves as a convenience wrapper around a
@ -20,8 +24,18 @@ import com.x8bit.bitwarden.data.auth.datasource.sdk.util.toUByte
class AuthSdkSourceImpl( class AuthSdkSourceImpl(
private val clientAuth: ClientAuth, private val clientAuth: ClientAuth,
private val clientPlatform: ClientPlatform, private val clientPlatform: ClientPlatform,
dispatcherManager: DispatcherManager,
featureFlagManager: BitwardenFeatureFlagManager,
) : AuthSdkSource { ) : AuthSdkSource {
private val ioScope = CoroutineScope(dispatcherManager.io)
init {
ioScope.launch {
clientPlatform.loadFlags(featureFlagManager.featureFlags)
}
}
override suspend fun getTrustDevice(): Result<TrustDeviceResponse> = runCatching { override suspend fun getTrustDevice(): Result<TrustDeviceResponse> = runCatching {
clientAuth.trustDevice() clientAuth.trustDevice()
} }

View file

@ -3,6 +3,8 @@ package com.x8bit.bitwarden.data.auth.datasource.sdk.di
import com.bitwarden.sdk.Client import com.bitwarden.sdk.Client
import com.x8bit.bitwarden.data.auth.datasource.sdk.AuthSdkSource import com.x8bit.bitwarden.data.auth.datasource.sdk.AuthSdkSource
import com.x8bit.bitwarden.data.auth.datasource.sdk.AuthSdkSourceImpl import com.x8bit.bitwarden.data.auth.datasource.sdk.AuthSdkSourceImpl
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
import com.x8bit.bitwarden.data.vault.datasource.sdk.BitwardenFeatureFlagManager
import dagger.Module import dagger.Module
import dagger.Provides import dagger.Provides
import dagger.hilt.InstallIn import dagger.hilt.InstallIn
@ -20,8 +22,12 @@ object AuthSdkModule {
@Singleton @Singleton
fun provideAuthSdkSource( fun provideAuthSdkSource(
client: Client, client: Client,
featureFlagManager: BitwardenFeatureFlagManager,
dispatcherManager: DispatcherManager,
): AuthSdkSource = AuthSdkSourceImpl( ): AuthSdkSource = AuthSdkSourceImpl(
clientAuth = client.auth(), clientAuth = client.auth(),
clientPlatform = client.platform(), clientPlatform = client.platform(),
featureFlagManager = featureFlagManager,
dispatcherManager = dispatcherManager,
) )
} }

View file

@ -11,7 +11,7 @@ interface SdkClientManager {
* Returns the cached [Client] instance for the given [userId], otherwise creates and caches * Returns the cached [Client] instance for the given [userId], otherwise creates and caches
* a new one and returns it. * a new one and returns it.
*/ */
fun getOrCreateClient(userId: String): Client suspend fun getOrCreateClient(userId: String): Client
/** /**
* Clears any resources from the [Client] associated with the given [userId] and removes it * Clears any resources from the [Client] associated with the given [userId] and removes it

View file

@ -1,16 +1,21 @@
package com.x8bit.bitwarden.data.platform.manager package com.x8bit.bitwarden.data.platform.manager
import com.bitwarden.sdk.Client import com.bitwarden.sdk.Client
import com.x8bit.bitwarden.data.vault.datasource.sdk.BitwardenFeatureFlagManager
/** /**
* Primary implementation of [SdkClientManager]. * Primary implementation of [SdkClientManager].
*/ */
class SdkClientManagerImpl( class SdkClientManagerImpl(
private val clientProvider: () -> Client = { Client(null) }, private val featureFlagManager: BitwardenFeatureFlagManager,
private val clientProvider: suspend () -> Client = {
Client(null)
.apply { platform().loadFlags(featureFlagManager.featureFlags) }
},
) : SdkClientManager { ) : SdkClientManager {
private val userIdToClientMap = mutableMapOf<String, Client>() private val userIdToClientMap = mutableMapOf<String, Client>()
override fun getOrCreateClient( override suspend fun getOrCreateClient(
userId: String, userId: String,
): Client = ): Client =
userIdToClientMap.getOrPut(key = userId) { clientProvider() } userIdToClientMap.getOrPut(key = userId) { clientProvider() }

View file

@ -34,6 +34,7 @@ import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManagerImpl import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManagerImpl
import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository
import com.x8bit.bitwarden.data.platform.repository.SettingsRepository import com.x8bit.bitwarden.data.platform.repository.SettingsRepository
import com.x8bit.bitwarden.data.vault.datasource.sdk.BitwardenFeatureFlagManager
import com.x8bit.bitwarden.data.vault.repository.VaultRepository import com.x8bit.bitwarden.data.vault.repository.VaultRepository
import dagger.Module import dagger.Module
import dagger.Provides import dagger.Provides
@ -97,7 +98,11 @@ object PlatformManagerModule {
@Provides @Provides
@Singleton @Singleton
fun provideSdkClientManager(): SdkClientManager = SdkClientManagerImpl() fun provideSdkClientManager(
featureFlagManager: BitwardenFeatureFlagManager,
): SdkClientManager = SdkClientManagerImpl(
featureFlagManager = featureFlagManager,
)
@Provides @Provides
@Singleton @Singleton

View file

@ -0,0 +1,11 @@
package com.x8bit.bitwarden.data.vault.datasource.sdk
/**
* Manages the available feature flags for the Bitwarden application.
*/
interface BitwardenFeatureFlagManager {
/**
* Returns a map of feature flags.
*/
val featureFlags: Map<String, Boolean>
}

View file

@ -0,0 +1,11 @@
package com.x8bit.bitwarden.data.vault.datasource.sdk
private const val CIPHER_KEY_ENCRYPTION_KEY = "enableCipherKeyEncryption"
/**
* Primary implementation of [BitwardenFeatureFlagManager].
*/
class BitwardenFeatureFlagManagerImpl : BitwardenFeatureFlagManager {
override val featureFlags: Map<String, Boolean>
get() = mapOf(CIPHER_KEY_ENCRYPTION_KEY to true)
}

View file

@ -388,7 +388,7 @@ class VaultSdkSourceImpl(
) )
} }
private fun getClient( private suspend fun getClient(
userId: String, userId: String,
): Client = sdkClientManager.getOrCreateClient(userId = userId) ): Client = sdkClientManager.getOrCreateClient(userId = userId)
} }

View file

@ -1,6 +1,8 @@
package com.x8bit.bitwarden.data.vault.datasource.sdk.di package com.x8bit.bitwarden.data.vault.datasource.sdk.di
import com.x8bit.bitwarden.data.platform.manager.SdkClientManager import com.x8bit.bitwarden.data.platform.manager.SdkClientManager
import com.x8bit.bitwarden.data.vault.datasource.sdk.BitwardenFeatureFlagManager
import com.x8bit.bitwarden.data.vault.datasource.sdk.BitwardenFeatureFlagManagerImpl
import com.x8bit.bitwarden.data.vault.datasource.sdk.VaultSdkSource import com.x8bit.bitwarden.data.vault.datasource.sdk.VaultSdkSource
import com.x8bit.bitwarden.data.vault.datasource.sdk.VaultSdkSourceImpl import com.x8bit.bitwarden.data.vault.datasource.sdk.VaultSdkSourceImpl
import dagger.Module import dagger.Module
@ -24,4 +26,9 @@ object VaultSdkModule {
VaultSdkSourceImpl( VaultSdkSourceImpl(
sdkClientManager = sdkClientManager, sdkClientManager = sdkClientManager,
) )
@Provides
@Singleton
fun providesBitwardenFeatureFlagManager(): BitwardenFeatureFlagManager =
BitwardenFeatureFlagManagerImpl()
} }

View file

@ -10,25 +10,46 @@ import com.bitwarden.crypto.TrustDeviceResponse
import com.bitwarden.sdk.ClientAuth import com.bitwarden.sdk.ClientAuth
import com.bitwarden.sdk.ClientPlatform import com.bitwarden.sdk.ClientPlatform
import com.x8bit.bitwarden.data.auth.datasource.sdk.model.PasswordStrength import com.x8bit.bitwarden.data.auth.datasource.sdk.model.PasswordStrength
import com.x8bit.bitwarden.data.platform.base.FakeDispatcherManager
import com.x8bit.bitwarden.data.platform.util.asFailure import com.x8bit.bitwarden.data.platform.util.asFailure
import com.x8bit.bitwarden.data.platform.util.asSuccess import com.x8bit.bitwarden.data.platform.util.asSuccess
import com.x8bit.bitwarden.data.vault.datasource.sdk.BitwardenFeatureFlagManager
import io.mockk.coEvery import io.mockk.coEvery
import io.mockk.coVerify import io.mockk.coVerify
import io.mockk.just
import io.mockk.mockk import io.mockk.mockk
import io.mockk.runs
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
class AuthSdkSourceTest { class AuthSdkSourceTest {
private val clientAuth = mockk<ClientAuth>() private val clientAuth = mockk<ClientAuth>()
private val clientPlatform = mockk<ClientPlatform>() private val clientPlatform = mockk<ClientPlatform> {
coEvery { loadFlags(any()) } just runs
}
private val featureFlagManager = mockk<BitwardenFeatureFlagManager> {
coEvery { featureFlags } returns emptyMap()
}
private val dispatcherManager = FakeDispatcherManager()
private val authSkdSource: AuthSdkSource = AuthSdkSourceImpl( private val authSkdSource: AuthSdkSource = AuthSdkSourceImpl(
clientAuth = clientAuth, clientAuth = clientAuth,
clientPlatform = clientPlatform, clientPlatform = clientPlatform,
featureFlagManager = featureFlagManager,
dispatcherManager = dispatcherManager,
) )
@BeforeEach
fun setup() {
coVerify(exactly = 1) {
featureFlagManager.featureFlags
clientPlatform.loadFlags(any())
}
}
@Test @Test
fun `getTrustDevice with trustDevice success should return success with correct data`() = fun `getTrustDevice with trustDevice success should return success with correct data`() =
runBlocking { runBlocking {

View file

@ -2,6 +2,7 @@ package com.x8bit.bitwarden.data.platform.manager
import io.mockk.mockk import io.mockk.mockk
import io.mockk.verify import io.mockk.verify
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNotEquals import org.junit.jupiter.api.Assertions.assertNotEquals
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
@ -10,11 +11,13 @@ class SdkClientManagerTest {
private val sdkClientManager = SdkClientManagerImpl( private val sdkClientManager = SdkClientManagerImpl(
clientProvider = { mockk(relaxed = true) }, clientProvider = { mockk(relaxed = true) },
featureFlagManager = mockk(),
) )
@Suppress("MaxLineLength") @Suppress("MaxLineLength")
@Test @Test
fun `getOrCreateClient should create a new client for each userId and return a cached client for subsequent calls`() { fun `getOrCreateClient should create a new client for each userId and return a cached client for subsequent calls`() =
runTest {
val userId = "userId" val userId = "userId"
val firstClient = sdkClientManager.getOrCreateClient(userId = userId) val firstClient = sdkClientManager.getOrCreateClient(userId = userId)
@ -29,7 +32,7 @@ class SdkClientManagerTest {
} }
@Test @Test
fun `destroyClient should call close on the Client and remove it from the cache`() { fun `destroyClient should call close on the Client and remove it from the cache`() = runTest {
val userId = "userId" val userId = "userId"
val firstClient = sdkClientManager.getOrCreateClient(userId = userId) val firstClient = sdkClientManager.getOrCreateClient(userId = userId)

View file

@ -0,0 +1,18 @@
package com.x8bit.bitwarden.data.vault.datasource.sdk
import org.junit.Test
import org.junit.jupiter.api.Assertions.assertEquals
class BitwardenFeatureFlagManagerTest {
private val bitwardenFeatureFlagManager = BitwardenFeatureFlagManagerImpl()
@Test
fun `featureFlags should return set feature flags`() {
val expected = mapOf("enableCipherKeyEncryption" to true)
val actual = bitwardenFeatureFlagManager.featureFlags
assertEquals(expected, actual)
}
}

View file

@ -67,7 +67,7 @@ class VaultSdkSourceTest {
every { exporters() } returns clientExporters every { exporters() } returns clientExporters
} }
private val sdkClientManager = mockk<SdkClientManager> { private val sdkClientManager = mockk<SdkClientManager> {
every { getOrCreateClient(any()) } returns client coEvery { getOrCreateClient(any()) } returns client
every { destroyClient(any()) } just runs every { destroyClient(any()) } just runs
} }
private val vaultSdkSource: VaultSdkSource = VaultSdkSourceImpl( private val vaultSdkSource: VaultSdkSource = VaultSdkSourceImpl(
@ -102,7 +102,7 @@ class VaultSdkSourceTest {
coVerify { coVerify {
clientCrypto.derivePinKey(pin) clientCrypto.derivePinKey(pin)
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -125,7 +125,7 @@ class VaultSdkSourceTest {
coVerify { coVerify {
clientCrypto.derivePinUserKey(encryptedPin = encryptedPin) clientCrypto.derivePinUserKey(encryptedPin = encryptedPin)
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -144,7 +144,7 @@ class VaultSdkSourceTest {
coVerify { coVerify {
clientCrypto.getUserEncryptionKey() clientCrypto.getUserEncryptionKey()
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -192,7 +192,7 @@ class VaultSdkSourceTest {
req = mockInitCryptoRequest, req = mockInitCryptoRequest,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -218,7 +218,7 @@ class VaultSdkSourceTest {
req = mockInitCryptoRequest, req = mockInitCryptoRequest,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -245,7 +245,7 @@ class VaultSdkSourceTest {
req = mockInitCryptoRequest, req = mockInitCryptoRequest,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -271,7 +271,7 @@ class VaultSdkSourceTest {
req = mockInitCryptoRequest, req = mockInitCryptoRequest,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -297,7 +297,7 @@ class VaultSdkSourceTest {
req = mockInitCryptoRequest, req = mockInitCryptoRequest,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -324,7 +324,7 @@ class VaultSdkSourceTest {
req = mockInitCryptoRequest, req = mockInitCryptoRequest,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -350,7 +350,7 @@ class VaultSdkSourceTest {
cipherView = mockCipher, cipherView = mockCipher,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -376,7 +376,7 @@ class VaultSdkSourceTest {
cipher = mockCipher, cipher = mockCipher,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -403,7 +403,7 @@ class VaultSdkSourceTest {
ciphers = mockCiphers, ciphers = mockCiphers,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -429,7 +429,7 @@ class VaultSdkSourceTest {
cipher = mockCiphers, cipher = mockCiphers,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -455,7 +455,7 @@ class VaultSdkSourceTest {
collection = mockCollection, collection = mockCollection,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -482,7 +482,7 @@ class VaultSdkSourceTest {
collections = mockCollectionsList, collections = mockCollectionsList,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -509,7 +509,7 @@ class VaultSdkSourceTest {
send = mockSend, send = mockSend,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -586,7 +586,7 @@ class VaultSdkSourceTest {
send = mockSend, send = mockSend,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -614,7 +614,7 @@ class VaultSdkSourceTest {
folder = mockFolder, folder = mockFolder,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -640,7 +640,7 @@ class VaultSdkSourceTest {
folder = mockFolder, folder = mockFolder,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -666,7 +666,7 @@ class VaultSdkSourceTest {
folders = mockFolders, folders = mockFolders,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -702,7 +702,7 @@ class VaultSdkSourceTest {
decryptedFilePath = "decrypted_path", decryptedFilePath = "decrypted_path",
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -728,7 +728,7 @@ class VaultSdkSourceTest {
passwordHistory = mockPasswordHistoryView, passwordHistory = mockPasswordHistoryView,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -754,7 +754,7 @@ class VaultSdkSourceTest {
list = mockPasswordHistoryList, list = mockPasswordHistoryList,
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test
@ -781,7 +781,7 @@ class VaultSdkSourceTest {
) )
} }
verify { sdkClientManager.getOrCreateClient(userId = userId) } coVerify { sdkClientManager.getOrCreateClient(userId = userId) }
} }
@Test @Test