From 878d427d2006d7b71486df2fb1cf36ed0df5d11c Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 10 Oct 2024 15:06:15 -0500 Subject: [PATCH] Notify when one-time keys are claimed Fix https://github.com/element-hq/synapse/issues/17474 --- synapse/handlers/e2e_keys.py | 6 +- synapse/notifier.py | 36 +++ .../sliding_sync/test_extension_e2ee.py | 212 +++++++++++++++--- 3 files changed, 224 insertions(+), 30 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index f78e66ad0a..9ff1960e11 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -70,6 +70,7 @@ class E2eKeysHandler: self.is_mine = hs.is_mine self.clock = hs.get_clock() self._worker_lock_handler = hs.get_worker_locks_handler() + self._notifier = hs.get_notifier() federation_registry = hs.get_federation_registry() @@ -615,7 +616,7 @@ class E2eKeysHandler: 3. Attempt to fetch fallback keys from the database. Args: - local_query: An iterable of tuples of (user ID, device ID, algorithm). + local_query: An iterable of tuples of (user ID, device ID, algorithm, number of keys to claim). always_include_fallback_keys: True to always include fallback keys. Returns: @@ -629,6 +630,7 @@ class E2eKeysHandler: ] otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query) + self._notifier.notify_one_time_keys_claimed(otk_results.keys()) # If the application services have not provided any keys via the C-S # API, query it directly for one-time keys. @@ -639,6 +641,7 @@ class E2eKeysHandler: appservice_results, not_found, ) = await self._appservice_handler.claim_e2e_one_time_keys(not_found) + self._notifier.notify_one_time_keys_claimed(appservice_results.keys()) else: appservice_results = {} @@ -693,6 +696,7 @@ class E2eKeysHandler: # For each user that does not have a one-time keys available, see if # there is a fallback key. fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query) + self._notifier.notify_one_time_keys_claimed(fallback_results.keys()) # Return the results in order, each item from the input query should # only appear once in the combined list. diff --git a/synapse/notifier.py b/synapse/notifier.py index 88f531182a..05e7f5b594 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -590,6 +590,8 @@ class Notifier: without waking up any of the normal user event streams""" self.notify_replication() + # FIXME: Should this be renamed to `wait_for_activity`? This listens for new events + # and when one-time keys are claimed which doesn't correspond to an event. async def wait_for_events( self, user_id: str, @@ -900,6 +902,40 @@ class Notifier: for cb in self._lock_released_callback: cb(instance_name, lock_name, lock_key) + def notify_one_time_keys_claimed( + self, + users: Union[StrCollection, Collection[UserID]], + ) -> None: + """ + Used by handlers to inform the notifier that a one-time key has been + claimed + """ + # Bail early if there is nothing to do + if not users: + return + + time_now_ms = self.clock.time_msec() + current_token = self.event_sources.get_current_token() + listeners: List["Deferred[StreamToken]"] = [] + for user in users: + user_stream = self.user_to_user_stream.get(str(user)) + if user_stream is None: + continue + + try: + listeners.extend( + user_stream.update_and_fetch_deferreds(current_token, time_now_ms) + ) + except Exception: + logger.exception("Failed to notify listener") + + # We resolve all these deferreds in one go so that we only need to + # call `PreserveLoggingContext` once, as it has a bunch of overhead + # (to calculate performance stats) + with PreserveLoggingContext(): + for listener in listeners: + listener.callback(current_token) + @attr.s(auto_attribs=True) class ReplicationNotifier: diff --git a/tests/rest/client/sliding_sync/test_extension_e2ee.py b/tests/rest/client/sliding_sync/test_extension_e2ee.py index 7ce6592d8f..808b0ed230 100644 --- a/tests/rest/client/sliding_sync/test_extension_e2ee.py +++ b/tests/rest/client/sliding_sync/test_extension_e2ee.py @@ -11,9 +11,11 @@ # See the GNU Affero General Public License for more details: # . # +import enum import logging -from parameterized import parameterized_class +from parameterized import parameterized, parameterized_class +from typing_extensions import assert_never from twisted.test.proto_helpers import MemoryReactor @@ -29,6 +31,12 @@ from tests.server import TimedOutException logger = logging.getLogger(__name__) +class E2eeBumpAction(enum.Enum): + device_lists = enum.auto() + one_time_keys = enum.auto() + fallback_one_time_keys = enum.auto() + + # FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the # foreground update for # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by @@ -147,19 +155,38 @@ class SlidingSyncE2eeExtensionTestCase(SlidingSyncBase): [], ) - def test_wait_for_new_data(self) -> None: + @parameterized.expand( + [ + ( + "bump_device_lists", + E2eeBumpAction.device_lists, + ), + ( + "bump_one_time_keys", + E2eeBumpAction.one_time_keys, + ), + ( + "bump_fallback_one_time_keys", + E2eeBumpAction.fallback_one_time_keys, + ), + ] + ) + def test_wait_for_new_data( + self, test_description: str, bump_action: E2eeBumpAction + ) -> None: """ Test to make sure that the Sliding Sync request waits for new data to arrive. (Only applies to incremental syncs with a `timeout` specified) """ + test_device_id = "TESTDEVICE" user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") + user1_tok = self.login(user1_id, "pass", device_id=test_device_id) user2_id = self.register_user("user2", "pass") user2_tok = self.login(user2_id, "pass") - test_device_id = "TESTDEVICE" + other_user_test_device_id = "OTHERUSERTESTDEVICE" user3_id = self.register_user("user3", "pass") - user3_tok = self.login(user3_id, "pass", device_id=test_device_id) + user3_tok = self.login(user3_id, "pass", device_id=other_user_test_device_id) room_id = self.helper.create_room_as(user2_id, tok=user2_tok) self.helper.join(room_id, user1_id, tok=user1_tok) @@ -186,34 +213,161 @@ class SlidingSyncE2eeExtensionTestCase(SlidingSyncBase): # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)` with self.assertRaises(TimedOutException): channel.await_result(timeout_ms=5000) - # Bump the device lists to trigger new results - # Have user3 update their device list - device_update_channel = self.make_request( - "PUT", - f"/devices/{test_device_id}", - { - "display_name": "New Device Name", - }, - access_token=user3_tok, - ) - self.assertEqual( - device_update_channel.code, 200, device_update_channel.json_body - ) + + if bump_action == E2eeBumpAction.device_lists: + # Bump the device lists to trigger new results + # Have user3 update their device list + device_update_channel = self.make_request( + "PUT", + f"/devices/{other_user_test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=user3_tok, + ) + self.assertEqual( + device_update_channel.code, 200, device_update_channel.json_body + ) + elif bump_action == E2eeBumpAction.one_time_keys: + # Upload one time keys for the user/device + keys: JsonDict = { + "alg1:k1": "key1", + "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}}, + "alg2:k3": {"key": "key3"}, + } + upload_keys_response = self.get_success( + self.e2e_keys_handler.upload_keys_for_user( + user1_id, test_device_id, {"one_time_keys": keys} + ) + ) + self.assertDictEqual( + upload_keys_response, + { + "one_time_key_counts": { + "alg1": 1, + "alg2": 2, + # Note that "signed_curve25519" is always returned in key count responses + # regardless of whether we uploaded any keys for it. This is necessary until + # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. + # + # Also related: + # https://github.com/element-hq/element-android/issues/3725 and + # https://github.com/matrix-org/synapse/issues/10456 + "signed_curve25519": 0, + } + }, + ) + + # Claim one of those new keys + self.get_success( + self.e2e_keys_handler.claim_local_one_time_keys( + local_query=[(user1_id, test_device_id, "alg2", 1)], + always_include_fallback_keys=False, + ) + ) + + # TODO: We should have a way to let clients differentiate between the states of: + # * no change in OTK count since the provided since token + # * the server has zero OTKs left for this device + # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298 + elif bump_action == E2eeBumpAction.fallback_one_time_keys: + # Upload a fallback key for the user/device + self.get_success( + self.e2e_keys_handler.upload_keys_for_user( + user1_id, + test_device_id, + { + "fallback_keys": { + "alg1:k1": "fallback_key1", + "alg2:k2": "fallback_key2", + } + }, + ) + ) + # We should now have an unused alg1 and alg2 key + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(user1_id, test_device_id) + ) + self.assertIncludes( + set(fallback_res), + {"alg1", "alg2"}, + exact=True, + message=str(fallback_res), + ) + + # Claim one of those fallback keys + self.get_success( + self.e2e_keys_handler.claim_local_one_time_keys( + local_query=[(user1_id, test_device_id, "alg1", 1)], + always_include_fallback_keys=False, + ) + ) + + # TODO: We should have a way to let clients differentiate between the states of: + # * no change in OTK count since the provided since token + # * the server has zero OTKs left for this device + # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298 + else: + assert_never(bump_action) + # Should respond before the 10 second timeout channel.await_result(timeout_ms=3000) self.assertEqual(channel.code, 200, channel.json_body) - # We should see the device list update - self.assertEqual( - channel.json_body["extensions"]["e2ee"] - .get("device_lists", {}) - .get("changed"), - [user3_id], - ) - self.assertEqual( - channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"), - [], - ) + # Check for the new data + # + if bump_action == E2eeBumpAction.device_lists: + # We should see the device list update + self.assertEqual( + channel.json_body["extensions"]["e2ee"] + .get("device_lists", {}) + .get("changed"), + [user3_id], + ) + self.assertEqual( + channel.json_body["extensions"]["e2ee"] + .get("device_lists", {}) + .get("left"), + [], + ) + elif bump_action == E2eeBumpAction.one_time_keys: + # We should see the one-time key count change + self.assertEqual( + channel.json_body["extensions"]["e2ee"].get( + "device_one_time_keys_count" + ), + { + "alg1": 1, + # Note: This changed from 2 -> 1 since we claimed one of them + "alg2": 1, + # Note that "signed_curve25519" is always returned in key count responses + # regardless of whether we uploaded any keys for it. This is necessary until + # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. + # + # Also related: + # https://github.com/element-hq/element-android/issues/3725 and + # https://github.com/matrix-org/synapse/issues/10456 + "signed_curve25519": 0, + }, + ) + elif bump_action == E2eeBumpAction.fallback_one_time_keys: + # Check for the unused fallback key types + self.assertIncludes( + set( + channel.json_body["extensions"]["e2ee"].get( + "device_unused_fallback_key_types", [] + ) + ), + {"alg2"}, + exact=True, + message=str( + channel.json_body["extensions"]["e2ee"].get( + "device_unused_fallback_key_types", + ) + ), + ) + else: + assert_never(bump_action) def test_wait_for_new_data_timeout(self) -> None: """