From 54456d3d4f2fdcd8f246ea7f438a6e73ed9f4e5a Mon Sep 17 00:00:00 2001
From: David Perez <david@livefront.com>
Date: Thu, 23 Jan 2025 16:51:51 -0600
Subject: [PATCH] PM-15804, PM-17130: Add logic to monitor when the screen on
 state to ensure the vault locks properly (#4618)

---
 .../vault/manager/VaultLockManagerImpl.kt     | 74 ++++++++++++++++---
 .../vault/manager/di/VaultManagerModule.kt    |  4 +
 .../vault/manager/VaultLockManagerTest.kt     | 65 ++++++++++++++++
 3 files changed, 131 insertions(+), 12 deletions(-)

diff --git a/app/src/main/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManagerImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManagerImpl.kt
index 9fadf13d1..f32eeecb5 100644
--- a/app/src/main/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManagerImpl.kt
+++ b/app/src/main/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManagerImpl.kt
@@ -1,5 +1,9 @@
 package com.x8bit.bitwarden.data.vault.manager
 
+import android.content.BroadcastReceiver
+import android.content.Context
+import android.content.Intent
+import android.content.IntentFilter
 import com.bitwarden.core.InitOrgCryptoRequest
 import com.bitwarden.core.InitUserCryptoMethod
 import com.bitwarden.core.InitUserCryptoRequest
@@ -50,6 +54,8 @@ import kotlinx.coroutines.flow.onCompletion
 import kotlinx.coroutines.flow.onEach
 import kotlinx.coroutines.flow.update
 import kotlinx.coroutines.launch
+import java.time.Clock
+import java.util.concurrent.ConcurrentHashMap
 import kotlin.time.Duration.Companion.minutes
 
 /**
@@ -62,6 +68,7 @@ private const val MAXIMUM_INVALID_UNLOCK_ATTEMPTS = 5
  */
 @Suppress("TooManyFunctions", "LongParameterList")
 class VaultLockManagerImpl(
+    private val clock: Clock,
     private val authDiskSource: AuthDiskSource,
     private val authSdkSource: AuthSdkSource,
     private val vaultSdkSource: VaultSdkSource,
@@ -70,13 +77,15 @@ class VaultLockManagerImpl(
     private val userLogoutManager: UserLogoutManager,
     private val trustedDeviceManager: TrustedDeviceManager,
     dispatcherManager: DispatcherManager,
+    context: Context,
 ) : VaultLockManager {
     private val unconfinedScope = CoroutineScope(dispatcherManager.unconfined)
 
     /**
-     * This [Map] tracks all active timeout [Job]s that are running using the user ID as the key.
+     * This [Map] tracks all active timeout [Job]s that are running and their associated data using
+     * the user ID as the key.
      */
-    private val userIdTimerJobMap = mutableMapOf<String, Job>()
+    private val userIdTimerJobMap: MutableMap<String, TimeoutJobData> = ConcurrentHashMap()
 
     private val activeUserId: String? get() = authDiskSource.userState?.activeUserId
 
@@ -96,6 +105,10 @@ class VaultLockManagerImpl(
         observeUserSwitchingChanges()
         observeVaultTimeoutChanges()
         observeUserLogoutResults()
+        context.registerReceiver(
+            ScreenStateBroadcastReceiver(),
+            IntentFilter(Intent.ACTION_SCREEN_ON),
+        )
     }
 
     override fun isVaultUnlocked(userId: String): Boolean =
@@ -363,7 +376,7 @@ class VaultLockManagerImpl(
 
     private fun handleOnForeground() {
         val userId = activeUserId ?: return
-        userIdTimerJobMap[userId]?.cancel()
+        userIdTimerJobMap.remove(key = userId)?.job?.cancel()
     }
 
     private fun observeUserSwitchingChanges() {
@@ -459,7 +472,7 @@ class VaultLockManagerImpl(
         currentActiveUserId: String,
     ) {
         // Make sure to clear the now-active user's timeout job.
-        userIdTimerJobMap[currentActiveUserId]?.cancel()
+        userIdTimerJobMap.remove(key = currentActiveUserId)?.job?.cancel()
         // Check if the user's timeout action should be performed as we switch away.
         checkForVaultTimeout(
             userId = previousActiveUserId,
@@ -529,7 +542,7 @@ class VaultLockManagerImpl(
                         handleTimeoutActionWithDelay(
                             userId = userId,
                             vaultTimeoutAction = vaultTimeoutAction,
-                            delayInMs = vaultTimeout
+                            delayMs = vaultTimeout
                                 .vaultTimeoutInMinutes
                                 ?.minutes
                                 ?.inWholeMilliseconds
@@ -542,20 +555,26 @@ class VaultLockManagerImpl(
     }
 
     /**
-     * Performs the [VaultTimeoutAction] for the given [userId] after the [delayInMs] has passed.
+     * Performs the [VaultTimeoutAction] for the given [userId] after the [delayMs] has passed.
      *
      * @see handleTimeoutAction
      */
     private fun handleTimeoutActionWithDelay(
         userId: String,
         vaultTimeoutAction: VaultTimeoutAction,
-        delayInMs: Long,
+        delayMs: Long,
     ) {
-        userIdTimerJobMap[userId]?.cancel()
-        userIdTimerJobMap[userId] = unconfinedScope.launch {
-            delay(timeMillis = delayInMs)
-            handleTimeoutAction(userId = userId, vaultTimeoutAction = vaultTimeoutAction)
-        }
+        userIdTimerJobMap.remove(key = userId)?.job?.cancel()
+        userIdTimerJobMap[userId] = TimeoutJobData(
+            job = unconfinedScope.launch {
+                delay(timeMillis = delayMs)
+                userIdTimerJobMap.remove(key = userId)
+                handleTimeoutAction(userId = userId, vaultTimeoutAction = vaultTimeoutAction)
+            },
+            vaultTimeoutAction = vaultTimeoutAction,
+            startTimeMs = clock.millis(),
+            durationMs = delayMs,
+        )
     }
 
     /**
@@ -601,6 +620,37 @@ class VaultLockManagerImpl(
         return (accounts.find { it.userId == userId }?.isLoggedIn) == false
     }
 
+    /**
+     * A custom [BroadcastReceiver] that listens for when the screen is powered on and restarts the
+     * vault timeout jobs to ensure they complete at the correct time.
+     *
+     * This is necessary because the [delay] function in a coroutine will not keep accurate time
+     * when the screen is off. We do not cancel the job when the screen is off, this allows the
+     * job to complete as-soon-as possible if the screen is powered off for an extended period.
+     */
+    private inner class ScreenStateBroadcastReceiver : BroadcastReceiver() {
+        override fun onReceive(context: Context, intent: Intent) {
+            userIdTimerJobMap.map { (userId, data) ->
+                handleTimeoutActionWithDelay(
+                    userId = userId,
+                    vaultTimeoutAction = data.vaultTimeoutAction,
+                    delayMs = data.durationMs - (clock.millis() - data.startTimeMs)
+                        .coerceAtLeast(minimumValue = 0L),
+                )
+            }
+        }
+    }
+
+    /**
+     * A wrapper class containing all relevant data concerning a timeout action [Job].
+     */
+    private data class TimeoutJobData(
+        val job: Job,
+        val vaultTimeoutAction: VaultTimeoutAction,
+        val startTimeMs: Long,
+        val durationMs: Long,
+    )
+
     /**
      * Helper sealed class which denotes the reason to check the vault timeout.
      */
diff --git a/app/src/main/java/com/x8bit/bitwarden/data/vault/manager/di/VaultManagerModule.kt b/app/src/main/java/com/x8bit/bitwarden/data/vault/manager/di/VaultManagerModule.kt
index 68a363953..e894c3439 100644
--- a/app/src/main/java/com/x8bit/bitwarden/data/vault/manager/di/VaultManagerModule.kt
+++ b/app/src/main/java/com/x8bit/bitwarden/data/vault/manager/di/VaultManagerModule.kt
@@ -71,6 +71,8 @@ object VaultManagerModule {
     @Provides
     @Singleton
     fun provideVaultLockManager(
+        @ApplicationContext context: Context,
+        clock: Clock,
         authDiskSource: AuthDiskSource,
         authSdkSource: AuthSdkSource,
         vaultSdkSource: VaultSdkSource,
@@ -81,6 +83,8 @@ object VaultManagerModule {
         trustedDeviceManager: TrustedDeviceManager,
     ): VaultLockManager =
         VaultLockManagerImpl(
+            context = context,
+            clock = clock,
             authDiskSource = authDiskSource,
             authSdkSource = authSdkSource,
             vaultSdkSource = vaultSdkSource,
diff --git a/app/src/test/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManagerTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManagerTest.kt
index 668a03bb7..e83827cb1 100644
--- a/app/src/test/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManagerTest.kt
+++ b/app/src/test/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManagerTest.kt
@@ -1,5 +1,8 @@
 package com.x8bit.bitwarden.data.vault.manager
 
+import android.content.BroadcastReceiver
+import android.content.Context
+import android.content.Intent
 import app.cash.turbine.test
 import com.bitwarden.core.InitOrgCryptoRequest
 import com.bitwarden.core.InitUserCryptoMethod
@@ -36,6 +39,7 @@ import io.mockk.every
 import io.mockk.just
 import io.mockk.mockk
 import io.mockk.runs
+import io.mockk.slot
 import io.mockk.verify
 import kotlinx.coroutines.ExperimentalCoroutinesApi
 import kotlinx.coroutines.async
@@ -49,11 +53,18 @@ import org.junit.jupiter.api.Assertions.assertEquals
 import org.junit.jupiter.api.Assertions.assertFalse
 import org.junit.jupiter.api.Assertions.assertTrue
 import org.junit.jupiter.api.Test
+import java.time.Clock
+import java.time.Instant
+import java.time.ZoneOffset
 import java.time.ZonedDateTime
 
 @OptIn(ExperimentalCoroutinesApi::class)
 @Suppress("LargeClass")
 class VaultLockManagerTest {
+    private val broadcastReceiver = slot<BroadcastReceiver>()
+    private val context: Context = mockk {
+        every { registerReceiver(capture(broadcastReceiver), any()) } returns null
+    }
     private val fakeAuthDiskSource = FakeAuthDiskSource()
     private val fakeAppStateManager = FakeAppStateManager()
     private val authSdkSource: AuthSdkSource = mockk {
@@ -88,6 +99,8 @@ class VaultLockManagerTest {
     private val fakeDispatcherManager = FakeDispatcherManager(unconfined = testDispatcher)
 
     private val vaultLockManager: VaultLockManager = VaultLockManagerImpl(
+        context = context,
+        clock = FIXED_CLOCK,
         authDiskSource = fakeAuthDiskSource,
         authSdkSource = authSdkSource,
         vaultSdkSource = vaultSdkSource,
@@ -98,6 +111,53 @@ class VaultLockManagerTest {
         dispatcherManager = fakeDispatcherManager,
     )
 
+    @Test
+    fun `broadcast receiver should be registered on initialization`() {
+        verify(exactly = 1) {
+            context.registerReceiver(any(), any())
+        }
+    }
+
+    @Test
+    fun `broadcast intent should reset active job`() {
+        setAccountTokens()
+        fakeAuthDiskSource.userState = MOCK_USER_STATE
+
+        // Setup state as unlocked
+        mutableVaultTimeoutStateFlow.value = VaultTimeout.OneMinute
+        mutableVaultTimeoutActionStateFlow.value = VaultTimeoutAction.LOCK
+        fakeAppStateManager.appForegroundState = AppForegroundState.FOREGROUNDED
+        verifyUnlockedVaultBlocking(userId = USER_ID)
+        assertTrue(vaultLockManager.isVaultUnlocked(USER_ID))
+
+        // Background the app
+        fakeAppStateManager.appForegroundState = AppForegroundState.BACKGROUNDED
+
+        // Advance by 30 seconds (half of what is required to lock the app)
+        testDispatcher.scheduler.advanceTimeBy(delayTimeMillis = 30 * 1000L)
+
+        // Still unlocked
+        assertTrue(vaultLockManager.isVaultUnlocked(USER_ID))
+
+        // Receive the screen on event
+        broadcastReceiver.captured.onReceive(context, Intent())
+
+        // Still unlocked
+        assertTrue(vaultLockManager.isVaultUnlocked(USER_ID))
+
+        // Because the test clock is fixed, this should mean that we need to advance the clock a
+        // full minute to get the vault to lock.
+        testDispatcher.scheduler.advanceTimeBy(delayTimeMillis = 30 * 1000L)
+
+        // Still unlocked
+        assertTrue(vaultLockManager.isVaultUnlocked(USER_ID))
+
+        testDispatcher.scheduler.advanceTimeBy(delayTimeMillis = 31 * 1000L)
+
+        // Finally locked
+        assertFalse(vaultLockManager.isVaultUnlocked(USER_ID))
+    }
+
     @Test
     fun `vaultStateEventFlow should emit Locked event when vault state changes to locked`() =
         runTest {
@@ -1587,6 +1647,11 @@ class VaultLockManagerTest {
     }
 }
 
+private val FIXED_CLOCK: Clock = Clock.fixed(
+    Instant.parse("2023-10-27T12:00:00Z"),
+    ZoneOffset.UTC,
+)
+
 private const val USER_ID = "mockId-1"
 
 private val MOCK_TIMEOUTS = VaultTimeout.Type.entries.map {