diff --git a/changelog.d/17216.misc b/changelog.d/17216.misc new file mode 100644 index 0000000000..bd55eeaa33 --- /dev/null +++ b/changelog.d/17216.misc @@ -0,0 +1 @@ +Improve performance of calculating device lists changes in `/sync`. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 412bee2b76..0432d97109 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -159,20 +159,32 @@ class DeviceWorkerHandler: @cancellable async def get_device_changes_in_shared_rooms( - self, user_id: str, room_ids: StrCollection, from_token: StreamToken + self, + user_id: str, + room_ids: StrCollection, + from_token: StreamToken, + now_token: Optional[StreamToken] = None, ) -> Set[str]: """Get the set of users whose devices have changed who share a room with the given user. """ + now_device_lists_key = self.store.get_device_stream_token() + if now_token: + now_device_lists_key = now_token.device_list_key + changed_users = await self.store.get_device_list_changes_in_rooms( - room_ids, from_token.device_list_key + room_ids, + from_token.device_list_key, + now_device_lists_key, ) if changed_users is not None: # We also check if the given user has changed their device. If # they're in no rooms then the above query won't include them. changed = await self.store.get_users_whose_devices_changed( - from_token.device_list_key, [user_id] + from_token.device_list_key, + [user_id], + to_key=now_device_lists_key, ) changed_users.update(changed) return changed_users @@ -190,7 +202,9 @@ class DeviceWorkerHandler: tracked_users.add(user_id) changed = await self.store.get_users_whose_devices_changed( - from_token.device_list_key, tracked_users + from_token.device_list_key, + tracked_users, + to_key=now_device_lists_key, ) return changed diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index d3d40e8682..b7917a99d6 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1886,38 +1886,14 @@ class SyncHandler: # Step 1a, check for changes in devices of users we share a room # with - # - # We do this in two different ways depending on what we have cached. - # If we already have a list of all the user that have changed since - # the last sync then it's likely more efficient to compare the rooms - # they're in with the rooms the syncing user is in. - # - # If we don't have that info cached then we get all the users that - # share a room with our user and check if those users have changed. - cache_result = self.store.get_cached_device_list_changes( - since_token.device_list_key - ) - if cache_result.hit: - changed_users = cache_result.entities - - result = await self.store.get_rooms_for_users(changed_users) - - for changed_user_id, entries in result.items(): - # Check if the changed user shares any rooms with the user, - # or if the changed user is the syncing user (as we always - # want to include device list updates of their own devices). - if user_id == changed_user_id or any( - rid in joined_room_ids for rid in entries - ): - users_that_have_changed.add(changed_user_id) - else: - users_that_have_changed = ( - await self._device_handler.get_device_changes_in_shared_rooms( - user_id, - sync_result_builder.joined_room_ids, - from_token=since_token, - ) + users_that_have_changed = ( + await self._device_handler.get_device_changes_in_shared_rooms( + user_id, + sync_result_builder.joined_room_ids, + from_token=since_token, + now_token=sync_result_builder.now_token, ) + ) # Step 1b, check for newly joined rooms for room_id in newly_joined_rooms: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 5e5387fdcb..cff88a87ec 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -112,6 +112,14 @@ class ReplicationDataHandler: token: stream token for this batch of rows rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ + all_room_ids: Set[str] = set() + if stream_name == DeviceListsStream.NAME: + prev_token = self.store.get_device_stream_token() + all_room_ids = await self.store.get_all_device_list_changes( + prev_token, token + ) + self.store.device_lists_in_rooms_have_changed(all_room_ids, token) + self.store.process_replication_rows(stream_name, instance_name, token, rows) # NOTE: this must be called after process_replication_rows to ensure any # cache invalidations are first handled before any stream ID advances. @@ -146,12 +154,6 @@ class ReplicationDataHandler: StreamKeyType.TO_DEVICE, token, users=entities ) elif stream_name == DeviceListsStream.NAME: - all_room_ids: Set[str] = set() - for row in rows: - if row.entity.startswith("@") and not row.is_signature: - room_ids = await self.store.get_rooms_for_user(row.entity) - all_room_ids.update(room_ids) - # `all_room_ids` can be large, so let's wake up those streams in batches for batched_room_ids in batch_iter(all_room_ids, 100): self.notifier.on_new_event( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 7ba3bc0d06..4f723d8da1 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -70,10 +70,7 @@ from synapse.types import ( from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.stream_change_cache import ( - AllEntitiesChangedResult, - StreamChangeCache, -) +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -132,6 +129,20 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): prefilled_cache=device_list_prefill, ) + device_list_room_prefill, min_device_list_room_id = self.db_pool.get_cache_dict( + db_conn, + "device_lists_changes_in_room", + entity_column="room_id", + stream_column="stream_id", + max_value=device_list_max, + limit=10000, + ) + self._device_list_room_stream_cache = StreamChangeCache( + "DeviceListRoomStreamChangeCache", + min_device_list_room_id, + prefilled_cache=device_list_room_prefill, + ) + ( user_signature_stream_prefill, user_signature_stream_list_id, @@ -211,6 +222,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): row.entity, token ) + def device_lists_in_rooms_have_changed( + self, room_ids: StrCollection, token: int + ) -> None: + "Record that device lists have changed in rooms" + for room_id in room_ids: + self._device_list_room_stream_cache.entity_has_changed(room_id, token) + def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() @@ -834,16 +852,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) return {device[0]: db_to_json(device[1]) for device in devices} - def get_cached_device_list_changes( - self, - from_key: int, - ) -> AllEntitiesChangedResult: - """Get set of users whose devices have changed since `from_key`, or None - if that information is not in our cache. - """ - - return self._device_list_stream_cache.get_all_entities_changed(from_key) - @cancellable async def get_all_devices_changed( self, @@ -1459,7 +1467,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): @cancellable async def get_device_list_changes_in_rooms( - self, room_ids: Collection[str], from_id: int + self, room_ids: Collection[str], from_id: int, to_id: int ) -> Optional[Set[str]]: """Return the set of users whose devices have changed in the given rooms since the given stream ID. @@ -1475,9 +1483,15 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): if min_stream_id > from_id: return None + changed_room_ids = self._device_list_room_stream_cache.get_entities_changed( + room_ids, from_id + ) + if not changed_room_ids: + return set() + sql = """ SELECT DISTINCT user_id FROM device_lists_changes_in_room - WHERE {clause} AND stream_id >= ? + WHERE {clause} AND stream_id > ? AND stream_id <= ? """ def _get_device_list_changes_in_rooms_txn( @@ -1489,11 +1503,12 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return {user_id for user_id, in txn} changes = set() - for chunk in batch_iter(room_ids, 1000): + for chunk in batch_iter(changed_room_ids, 1000): clause, args = make_in_list_sql_clause( self.database_engine, "room_id", chunk ) args.append(from_id) + args.append(to_id) changes |= await self.db_pool.runInteraction( "get_device_list_changes_in_rooms", @@ -1504,6 +1519,34 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return changes + async def get_all_device_list_changes(self, from_id: int, to_id: int) -> Set[str]: + """Return the set of rooms where devices have changed since the given + stream ID. + + Will raise an exception if the given stream ID is too old. + """ + + min_stream_id = await self._get_min_device_lists_changes_in_room() + + if min_stream_id > from_id: + raise Exception("stream ID is too old") + + sql = """ + SELECT DISTINCT room_id FROM device_lists_changes_in_room + WHERE stream_id > ? AND stream_id <= ? + """ + + def _get_all_device_list_changes_txn( + txn: LoggingTransaction, + ) -> Set[str]: + txn.execute(sql, (from_id, to_id)) + return {room_id for room_id, in txn} + + return await self.db_pool.runInteraction( + "get_all_device_list_changes", + _get_all_device_list_changes_txn, + ) + async def get_device_list_changes_in_room( self, room_id: str, min_stream_id: int ) -> Collection[Tuple[str, str]]: @@ -1964,8 +2007,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): async def add_device_change_to_streams( self, user_id: str, - device_ids: Collection[str], - room_ids: Collection[str], + device_ids: StrCollection, + room_ids: StrCollection, ) -> Optional[int]: """Persist that a user's devices have been updated, and which hosts (if any) should be poked. @@ -2147,8 +2190,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self, txn: LoggingTransaction, user_id: str, - device_ids: Iterable[str], - room_ids: Collection[str], + device_ids: StrCollection, + room_ids: StrCollection, stream_ids: List[int], context: Dict[str, str], ) -> None: @@ -2186,6 +2229,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ], ) + txn.call_after( + self.device_lists_in_rooms_have_changed, room_ids, max(stream_ids) + ) + async def get_uncoverted_outbound_room_pokes( self, start_stream_id: int, start_room_id: str, limit: int = 10 ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: