Check isAuthenticated status within PushManager (#600)

This commit is contained in:
Sean Weiser 2024-01-13 11:38:45 -06:00 committed by Álison Fernandes
parent 3d75867a15
commit 8efd9d2c8a
4 changed files with 155 additions and 59 deletions

View file

@ -94,6 +94,11 @@ class PushManagerImpl @Inject constructor(
override val syncSendUpsertFlow: SharedFlow<SyncSendUpsertData>
get() = mutableSyncSendUpsertSharedFlow.asSharedFlow()
private val activeUserId: String?
get() = authDiskSource.userState?.activeUserId
private val isLoggedIn: Boolean
get() = authDiskSource.userState?.activeAccount?.isLoggedIn == true
init {
authDiskSource
.userStateFlow
@ -113,7 +118,7 @@ class PushManagerImpl @Inject constructor(
if (authDiskSource.uniqueAppId == notification.contextId) return
val userId = authDiskSource.userState?.activeUserId
val userId = activeUserId
when (val type = notification.notificationType) {
NotificationType.AUTH_REQUEST,
@ -130,7 +135,7 @@ class PushManagerImpl @Inject constructor(
}
NotificationType.LOG_OUT -> {
if (userId == null) return
if (!isLoggedIn) return
mutableLogoutSharedFlow.tryEmit(Unit)
}
@ -139,7 +144,7 @@ class PushManagerImpl @Inject constructor(
-> {
val payload: NotificationPayload.SyncCipherNotification =
json.decodeFromJsonElement(notification.payload)
if (!payload.userMatchesNotification(userId)) return
if (!isLoggedIn || !payload.userMatchesNotification(userId)) return
mutableSyncCipherUpsertSharedFlow.tryEmit(
SyncCipherUpsertData(
cipherId = payload.id,
@ -154,7 +159,7 @@ class PushManagerImpl @Inject constructor(
-> {
val payload: NotificationPayload.SyncCipherNotification =
json.decodeFromJsonElement(notification.payload)
if (!payload.userMatchesNotification(userId)) return
if (!isLoggedIn || !payload.userMatchesNotification(userId)) return
mutableSyncCipherDeleteSharedFlow.tryEmit(
SyncCipherDeleteData(payload.id),
)
@ -173,7 +178,7 @@ class PushManagerImpl @Inject constructor(
-> {
val payload: NotificationPayload.SyncFolderNotification =
json.decodeFromJsonElement(notification.payload)
if (!payload.userMatchesNotification(userId)) return
if (!isLoggedIn || !payload.userMatchesNotification(userId)) return
mutableSyncFolderUpsertSharedFlow.tryEmit(
SyncFolderUpsertData(
folderId = payload.id,
@ -186,7 +191,7 @@ class PushManagerImpl @Inject constructor(
NotificationType.SYNC_FOLDER_DELETE -> {
val payload: NotificationPayload.SyncFolderNotification =
json.decodeFromJsonElement(notification.payload)
if (!payload.userMatchesNotification(userId)) return
if (!isLoggedIn || !payload.userMatchesNotification(userId)) return
mutableSyncFolderDeleteSharedFlow.tryEmit(
SyncFolderDeleteData(payload.id),
@ -194,7 +199,7 @@ class PushManagerImpl @Inject constructor(
}
NotificationType.SYNC_ORG_KEYS -> {
if (userId == null) return
if (!isLoggedIn) return
mutableSyncOrgKeysSharedFlow.tryEmit(Unit)
}
@ -203,7 +208,7 @@ class PushManagerImpl @Inject constructor(
-> {
val payload: NotificationPayload.SyncSendNotification =
json.decodeFromJsonElement(notification.payload)
if (!payload.userMatchesNotification(userId)) return
if (!isLoggedIn || !payload.userMatchesNotification(userId)) return
mutableSyncSendUpsertSharedFlow.tryEmit(
SyncSendUpsertData(
sendId = payload.id,
@ -216,7 +221,7 @@ class PushManagerImpl @Inject constructor(
NotificationType.SYNC_SEND_DELETE -> {
val payload: NotificationPayload.SyncSendNotification =
json.decodeFromJsonElement(notification.payload)
if (!payload.userMatchesNotification(userId)) return
if (!isLoggedIn || !payload.userMatchesNotification(userId)) return
mutableSyncSendDeleteSharedFlow.tryEmit(
SyncSendDeleteData(payload.id),
)
@ -226,7 +231,9 @@ class PushManagerImpl @Inject constructor(
override fun registerPushTokenIfNecessary(token: String) {
pushDiskSource.registeredPushToken = token
val userId = authDiskSource.userState?.activeUserId ?: return
if (!isLoggedIn) return
val userId = activeUserId ?: return
ioScope.launch {
registerPushTokenIfNecessaryInternal(
userId = userId,
@ -235,8 +242,10 @@ class PushManagerImpl @Inject constructor(
}
}
@Suppress("ReturnCount")
override fun registerStoredPushTokenIfNecessary() {
val userId = authDiskSource.userState?.activeUserId ?: return
if (!isLoggedIn) return
val userId = activeUserId ?: return
// If the last registered token is from less than a day before, skip this for now
val lastRegistration = pushDiskSource.getLastPushTokenRegistrationDate(userId)?.toInstant()

View file

@ -1,14 +1,19 @@
package com.x8bit.bitwarden.data.platform.manager.di
import android.content.Context
import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource
import com.x8bit.bitwarden.data.auth.repository.AuthRepository
import com.x8bit.bitwarden.data.platform.datasource.disk.PushDiskSource
import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors
import com.x8bit.bitwarden.data.platform.datasource.network.service.PushService
import com.x8bit.bitwarden.data.platform.manager.AppForegroundManager
import com.x8bit.bitwarden.data.platform.manager.AppForegroundManagerImpl
import com.x8bit.bitwarden.data.platform.manager.NetworkConfigManager
import com.x8bit.bitwarden.data.platform.manager.NetworkConfigManagerImpl
import com.x8bit.bitwarden.data.platform.manager.PushManager
import com.x8bit.bitwarden.data.platform.manager.PushManagerImpl
import com.x8bit.bitwarden.data.platform.manager.SdkClientManager
import com.x8bit.bitwarden.data.platform.manager.SdkClientManagerImpl
import com.x8bit.bitwarden.data.platform.manager.clipboard.BitwardenClipboardManager
@ -21,6 +26,7 @@ import dagger.Provides
import dagger.hilt.InstallIn
import dagger.hilt.android.qualifiers.ApplicationContext
import dagger.hilt.components.SingletonComponent
import kotlinx.serialization.json.Json
import java.time.Clock
import javax.inject.Singleton
@ -72,4 +78,22 @@ object PlatformManagerModule {
refreshAuthenticator = refreshAuthenticator,
dispatcherManager = dispatcherManager,
)
@Provides
@Singleton
fun providePushManager(
authDiskSource: AuthDiskSource,
pushDiskSource: PushDiskSource,
pushService: PushService,
dispatcherManager: DispatcherManager,
clock: Clock,
json: Json,
): PushManager = PushManagerImpl(
authDiskSource = authDiskSource,
pushDiskSource = pushDiskSource,
pushService = pushService,
dispatcherManager = dispatcherManager,
clock = clock,
json = json,
)
}

View file

@ -1,41 +0,0 @@
package com.x8bit.bitwarden.data.platform.manager.di
import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource
import com.x8bit.bitwarden.data.platform.datasource.disk.PushDiskSource
import com.x8bit.bitwarden.data.platform.datasource.network.service.PushService
import com.x8bit.bitwarden.data.platform.manager.PushManager
import com.x8bit.bitwarden.data.platform.manager.PushManagerImpl
import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager
import dagger.Module
import dagger.Provides
import dagger.hilt.InstallIn
import dagger.hilt.components.SingletonComponent
import kotlinx.serialization.json.Json
import java.time.Clock
import javax.inject.Singleton
/**
* Provides repositories in the push package.
*/
@Module
@InstallIn(SingletonComponent::class)
object PushManagerModule {
@Provides
@Singleton
fun providePushManager(
authDiskSource: AuthDiskSource,
pushDiskSource: PushDiskSource,
pushService: PushService,
dispatcherManager: DispatcherManager,
clock: Clock,
json: Json,
): PushManager = PushManagerImpl(
authDiskSource = authDiskSource,
pushDiskSource = pushDiskSource,
pushService = pushService,
dispatcherManager = dispatcherManager,
clock = clock,
json = json,
)
}

View file

@ -2,6 +2,7 @@ package com.x8bit.bitwarden.data.platform.manager
import app.cash.turbine.test
import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource
import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import com.x8bit.bitwarden.data.auth.datasource.disk.util.FakeAuthDiskSource
import com.x8bit.bitwarden.data.platform.base.FakeDispatcherManager
@ -23,6 +24,7 @@ import com.x8bit.bitwarden.data.platform.util.asFailure
import com.x8bit.bitwarden.data.platform.util.asSuccess
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
import io.mockk.mockk
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assertions.assertEquals
@ -103,11 +105,75 @@ class PushManagerTest {
}
@Nested
inner class MatchingUser {
inner class LoggedOutUserState {
@BeforeEach
fun setUp() {
val userId = "any user ID"
val account = mockk<AccountJson> {
every { isLoggedIn } returns false
}
authDiskSource.userState = UserStateJson(userId, mapOf(userId to account))
}
@Test
fun `onMessageReceived logout does nothing`() = runTest {
pushManager.logoutFlow.test {
pushManager.onMessageReceived(LOGOUT_NOTIFICATION_JSON)
expectNoEvents()
}
}
@Test
fun `onMessageReceived sync ciphers emits to fullSyncFlow`() = runTest {
pushManager.fullSyncFlow.test {
pushManager.onMessageReceived(SYNC_CIPHERS_NOTIFICATION_JSON)
assertEquals(
Unit,
awaitItem(),
)
}
}
@Test
fun `onMessageReceived sync org keys does nothing`() = runTest {
pushManager.fullSyncFlow.test {
pushManager.onMessageReceived(SYNC_ORG_KEYS_NOTIFICATION_JSON)
expectNoEvents()
}
}
@Test
fun `onMessageReceived sync settings emits to fullSyncFlow`() = runTest {
pushManager.fullSyncFlow.test {
pushManager.onMessageReceived(SYNC_SETTINGS_NOTIFICATION_JSON)
assertEquals(
Unit,
awaitItem(),
)
}
}
@Test
fun `onMessageReceived sync vault emits to fullSyncFlow`() = runTest {
pushManager.fullSyncFlow.test {
pushManager.onMessageReceived(SYNC_VAULT_NOTIFICATION_JSON)
assertEquals(
Unit,
awaitItem(),
)
}
}
}
@Nested
inner class MatchingLoggedInUser {
@BeforeEach
fun setUp() {
val userId = "078966a2-93c2-4618-ae2a-0a2394c88d37"
authDiskSource.userState = UserStateJson(userId, mapOf(userId to mockk()))
val account = mockk<AccountJson> {
every { isLoggedIn } returns true
}
authDiskSource.userState = UserStateJson(userId, mapOf(userId to account))
}
@Test
@ -264,11 +330,14 @@ class PushManagerTest {
}
@Nested
inner class NonMatchingUser {
inner class NonMatchingLoggedInUser {
@BeforeEach
fun setUp() {
val userId = "bad user ID"
authDiskSource.userState = UserStateJson(userId, mapOf(userId to mockk()))
val account = mockk<AccountJson> {
every { isLoggedIn } returns true
}
authDiskSource.userState = UserStateJson(userId, mapOf(userId to account))
}
@Test
@ -405,7 +474,10 @@ class PushManagerTest {
@BeforeEach
fun setUp() {
val userId = "any user ID"
authDiskSource.userState = UserStateJson(userId, mapOf(userId to mockk()))
val account = mockk<AccountJson> {
every { isLoggedIn } returns true
}
authDiskSource.userState = UserStateJson(userId, mapOf(userId to account))
}
@Test
@ -467,6 +539,35 @@ class PushManagerTest {
@Nested
inner class PushNotificationRegistration {
@Nested
inner class LoggedOutUserState {
@BeforeEach
fun setUp() {
val userId = "any user ID"
val account = mockk<AccountJson> {
every { isLoggedIn } returns false
}
authDiskSource.userState = UserStateJson(userId, mapOf(userId to account))
}
@Test
fun `registerPushTokenIfNecessary should update registeredPushToken`() {
assertEquals(null, pushDiskSource.registeredPushToken)
val token = "token"
pushManager.registerPushTokenIfNecessary(token)
assertEquals(token, pushDiskSource.registeredPushToken)
}
@Test
fun `registerStoredPushTokenIfNecessary should do nothing`() {
pushManager.registerStoredPushTokenIfNecessary()
assertNull(pushDiskSource.registeredPushToken)
}
}
@Nested
inner class NullUserState {
@BeforeEach
@ -493,14 +594,17 @@ class PushManagerTest {
}
@Nested
inner class NonNullUserState {
inner class NonNullLoggedInUserState {
private val existingToken = "existingToken"
private val userId = "userId"
@BeforeEach
fun setUp() {
pushDiskSource.storeCurrentPushToken(userId, existingToken)
authDiskSource.userState = UserStateJson(userId, mapOf(userId to mockk()))
val account = mockk<AccountJson> {
every { isLoggedIn } returns true
}
authDiskSource.userState = UserStateJson(userId, mapOf(userId to account))
}
@Suppress("MaxLineLength")