Merge branch 'erikj/device_list_sync_perf' into matrix-org-hotfixes

This commit is contained in:
Olivier 'reivilibre 2024-05-18 15:21:52 +01:00
commit 233e25e193
5 changed files with 102 additions and 62 deletions

1
changelog.d/17216.misc Normal file
View file

@ -0,0 +1 @@
Improve performance of calculating device lists changes in `/sync`.

View file

@ -159,20 +159,32 @@ class DeviceWorkerHandler:
@cancellable @cancellable
async def get_device_changes_in_shared_rooms( async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: StrCollection, from_token: StreamToken self,
user_id: str,
room_ids: StrCollection,
from_token: StreamToken,
now_token: Optional[StreamToken] = None,
) -> Set[str]: ) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with """Get the set of users whose devices have changed who share a room with
the given user. the given user.
""" """
now_device_lists_key = self.store.get_device_stream_token()
if now_token:
now_device_lists_key = now_token.device_list_key
changed_users = await self.store.get_device_list_changes_in_rooms( changed_users = await self.store.get_device_list_changes_in_rooms(
room_ids, from_token.device_list_key room_ids,
from_token.device_list_key,
now_device_lists_key,
) )
if changed_users is not None: if changed_users is not None:
# We also check if the given user has changed their device. If # We also check if the given user has changed their device. If
# they're in no rooms then the above query won't include them. # they're in no rooms then the above query won't include them.
changed = await self.store.get_users_whose_devices_changed( changed = await self.store.get_users_whose_devices_changed(
from_token.device_list_key, [user_id] from_token.device_list_key,
[user_id],
to_key=now_device_lists_key,
) )
changed_users.update(changed) changed_users.update(changed)
return changed_users return changed_users
@ -190,7 +202,9 @@ class DeviceWorkerHandler:
tracked_users.add(user_id) tracked_users.add(user_id)
changed = await self.store.get_users_whose_devices_changed( changed = await self.store.get_users_whose_devices_changed(
from_token.device_list_key, tracked_users from_token.device_list_key,
tracked_users,
to_key=now_device_lists_key,
) )
return changed return changed

View file

@ -1886,38 +1886,14 @@ class SyncHandler:
# Step 1a, check for changes in devices of users we share a room # Step 1a, check for changes in devices of users we share a room
# with # with
# users_that_have_changed = (
# We do this in two different ways depending on what we have cached. await self._device_handler.get_device_changes_in_shared_rooms(
# If we already have a list of all the user that have changed since user_id,
# the last sync then it's likely more efficient to compare the rooms sync_result_builder.joined_room_ids,
# they're in with the rooms the syncing user is in. from_token=since_token,
# now_token=sync_result_builder.now_token,
# If we don't have that info cached then we get all the users that
# share a room with our user and check if those users have changed.
cache_result = self.store.get_cached_device_list_changes(
since_token.device_list_key
)
if cache_result.hit:
changed_users = cache_result.entities
result = await self.store.get_rooms_for_users(changed_users)
for changed_user_id, entries in result.items():
# Check if the changed user shares any rooms with the user,
# or if the changed user is the syncing user (as we always
# want to include device list updates of their own devices).
if user_id == changed_user_id or any(
rid in joined_room_ids for rid in entries
):
users_that_have_changed.add(changed_user_id)
else:
users_that_have_changed = (
await self._device_handler.get_device_changes_in_shared_rooms(
user_id,
sync_result_builder.joined_room_ids,
from_token=since_token,
)
) )
)
# Step 1b, check for newly joined rooms # Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms: for room_id in newly_joined_rooms:

View file

@ -112,6 +112,14 @@ class ReplicationDataHandler:
token: stream token for this batch of rows token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
""" """
all_room_ids: Set[str] = set()
if stream_name == DeviceListsStream.NAME:
prev_token = self.store.get_device_stream_token()
all_room_ids = await self.store.get_all_device_list_changes(
prev_token, token
)
self.store.device_lists_in_rooms_have_changed(all_room_ids, token)
self.store.process_replication_rows(stream_name, instance_name, token, rows) self.store.process_replication_rows(stream_name, instance_name, token, rows)
# NOTE: this must be called after process_replication_rows to ensure any # NOTE: this must be called after process_replication_rows to ensure any
# cache invalidations are first handled before any stream ID advances. # cache invalidations are first handled before any stream ID advances.
@ -146,12 +154,6 @@ class ReplicationDataHandler:
StreamKeyType.TO_DEVICE, token, users=entities StreamKeyType.TO_DEVICE, token, users=entities
) )
elif stream_name == DeviceListsStream.NAME: elif stream_name == DeviceListsStream.NAME:
all_room_ids: Set[str] = set()
for row in rows:
if row.entity.startswith("@") and not row.is_signature:
room_ids = await self.store.get_rooms_for_user(row.entity)
all_room_ids.update(room_ids)
# `all_room_ids` can be large, so let's wake up those streams in batches # `all_room_ids` can be large, so let's wake up those streams in batches
for batched_room_ids in batch_iter(all_room_ids, 100): for batched_room_ids in batch_iter(all_room_ids, 100):
self.notifier.on_new_event( self.notifier.on_new_event(

View file

@ -70,10 +70,7 @@ from synapse.types import (
from synapse.util import json_decoder, json_encoder from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.stream_change_cache import ( from synapse.util.caches.stream_change_cache import StreamChangeCache
AllEntitiesChangedResult,
StreamChangeCache,
)
from synapse.util.cancellation import cancellable from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr from synapse.util.stringutils import shortstr
@ -132,6 +129,20 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
prefilled_cache=device_list_prefill, prefilled_cache=device_list_prefill,
) )
device_list_room_prefill, min_device_list_room_id = self.db_pool.get_cache_dict(
db_conn,
"device_lists_changes_in_room",
entity_column="room_id",
stream_column="stream_id",
max_value=device_list_max,
limit=10000,
)
self._device_list_room_stream_cache = StreamChangeCache(
"DeviceListRoomStreamChangeCache",
min_device_list_room_id,
prefilled_cache=device_list_room_prefill,
)
( (
user_signature_stream_prefill, user_signature_stream_prefill,
user_signature_stream_list_id, user_signature_stream_list_id,
@ -211,6 +222,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
row.entity, token row.entity, token
) )
def device_lists_in_rooms_have_changed(
self, room_ids: StrCollection, token: int
) -> None:
"Record that device lists have changed in rooms"
for room_id in room_ids:
self._device_list_room_stream_cache.entity_has_changed(room_id, token)
def get_device_stream_token(self) -> int: def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token() return self._device_list_id_gen.get_current_token()
@ -834,16 +852,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
) )
return {device[0]: db_to_json(device[1]) for device in devices} return {device[0]: db_to_json(device[1]) for device in devices}
def get_cached_device_list_changes(
self,
from_key: int,
) -> AllEntitiesChangedResult:
"""Get set of users whose devices have changed since `from_key`, or None
if that information is not in our cache.
"""
return self._device_list_stream_cache.get_all_entities_changed(from_key)
@cancellable @cancellable
async def get_all_devices_changed( async def get_all_devices_changed(
self, self,
@ -1459,7 +1467,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
@cancellable @cancellable
async def get_device_list_changes_in_rooms( async def get_device_list_changes_in_rooms(
self, room_ids: Collection[str], from_id: int self, room_ids: Collection[str], from_id: int, to_id: int
) -> Optional[Set[str]]: ) -> Optional[Set[str]]:
"""Return the set of users whose devices have changed in the given rooms """Return the set of users whose devices have changed in the given rooms
since the given stream ID. since the given stream ID.
@ -1475,9 +1483,15 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
if min_stream_id > from_id: if min_stream_id > from_id:
return None return None
changed_room_ids = self._device_list_room_stream_cache.get_entities_changed(
room_ids, from_id
)
if not changed_room_ids:
return set()
sql = """ sql = """
SELECT DISTINCT user_id FROM device_lists_changes_in_room SELECT DISTINCT user_id FROM device_lists_changes_in_room
WHERE {clause} AND stream_id >= ? WHERE {clause} AND stream_id > ? AND stream_id <= ?
""" """
def _get_device_list_changes_in_rooms_txn( def _get_device_list_changes_in_rooms_txn(
@ -1489,11 +1503,12 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return {user_id for user_id, in txn} return {user_id for user_id, in txn}
changes = set() changes = set()
for chunk in batch_iter(room_ids, 1000): for chunk in batch_iter(changed_room_ids, 1000):
clause, args = make_in_list_sql_clause( clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", chunk self.database_engine, "room_id", chunk
) )
args.append(from_id) args.append(from_id)
args.append(to_id)
changes |= await self.db_pool.runInteraction( changes |= await self.db_pool.runInteraction(
"get_device_list_changes_in_rooms", "get_device_list_changes_in_rooms",
@ -1504,6 +1519,34 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return changes return changes
async def get_all_device_list_changes(self, from_id: int, to_id: int) -> Set[str]:
"""Return the set of rooms where devices have changed since the given
stream ID.
Will raise an exception if the given stream ID is too old.
"""
min_stream_id = await self._get_min_device_lists_changes_in_room()
if min_stream_id > from_id:
raise Exception("stream ID is too old")
sql = """
SELECT DISTINCT room_id FROM device_lists_changes_in_room
WHERE stream_id > ? AND stream_id <= ?
"""
def _get_all_device_list_changes_txn(
txn: LoggingTransaction,
) -> Set[str]:
txn.execute(sql, (from_id, to_id))
return {room_id for room_id, in txn}
return await self.db_pool.runInteraction(
"get_all_device_list_changes",
_get_all_device_list_changes_txn,
)
async def get_device_list_changes_in_room( async def get_device_list_changes_in_room(
self, room_id: str, min_stream_id: int self, room_id: str, min_stream_id: int
) -> Collection[Tuple[str, str]]: ) -> Collection[Tuple[str, str]]:
@ -1964,8 +2007,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
async def add_device_change_to_streams( async def add_device_change_to_streams(
self, self,
user_id: str, user_id: str,
device_ids: Collection[str], device_ids: StrCollection,
room_ids: Collection[str], room_ids: StrCollection,
) -> Optional[int]: ) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts """Persist that a user's devices have been updated, and which hosts
(if any) should be poked. (if any) should be poked.
@ -2147,8 +2190,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self, self,
txn: LoggingTransaction, txn: LoggingTransaction,
user_id: str, user_id: str,
device_ids: Iterable[str], device_ids: StrCollection,
room_ids: Collection[str], room_ids: StrCollection,
stream_ids: List[int], stream_ids: List[int],
context: Dict[str, str], context: Dict[str, str],
) -> None: ) -> None:
@ -2186,6 +2229,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
], ],
) )
txn.call_after(
self.device_lists_in_rooms_have_changed, room_ids, max(stream_ids)
)
async def get_uncoverted_outbound_room_pokes( async def get_uncoverted_outbound_room_pokes(
self, start_stream_id: int, start_room_id: str, limit: int = 10 self, start_stream_id: int, start_room_id: str, limit: int = 10
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: