diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 341e219283..c231c9eeaa 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -825,7 +825,7 @@ class EventsPersistenceStorageController: if state_key[0] == EventTypes.Member and self.is_mine_id(state_key[1]): membership_event_id_to_user_id_map[event_id] = state_key[1] - event_id_to_sender_map: Dict[str, str] = {} + membership_event_map: Dict[str, EventBase] = {} # In normal event persist scenarios, we should be able to find the # membership events in the `events_and_contexts` given to us but it's # possible a state reset happened which added us to the room without a @@ -834,39 +834,34 @@ class EventsPersistenceStorageController: for membership_event_id in membership_event_id_to_user_id_map.keys(): membership_event = event_map.get(membership_event_id) if membership_event: - event_id_to_sender_map[membership_event_id] = ( - membership_event.sender - ) + membership_event_map[membership_event_id] = membership_event else: missing_membership_event_ids.add(membership_event_id) # Otherwise, we need to find a couple events that we were reset to. if missing_membership_event_ids: - remaining_event_id_to_sender_map = ( - await self.main_store.get_sender_for_event_ids( - missing_membership_event_ids - ) + remaining_events = await self.main_store.get_events( + missing_membership_event_ids ) # There shouldn't be any missing events assert ( - remaining_event_id_to_sender_map.keys() - == missing_membership_event_ids - ), missing_membership_event_ids.difference( - remaining_event_id_to_sender_map.keys() - ) - event_id_to_sender_map.update(remaining_event_id_to_sender_map) + remaining_events.keys() == missing_membership_event_ids + ), missing_membership_event_ids.difference(remaining_events.keys()) + membership_event_map.update(remaining_events) membership_infos_to_insert_membership_snapshots = [ - { - "user_id": user_id, - "sender": event_id_to_sender_map[membership_event_id], - "membership_event_id": membership_event_id, - } + SlidingSyncMembershipInfo( + user_id=user_id, + sender=membership_event_map[membership_event_id].sender, + membership_event_id=membership_event_id, + membership=membership_event_map[membership_event_id].membership, + membership_event_stream_ordering=None, + ) for membership_event_id, user_id in membership_event_id_to_user_id_map.items() ] if membership_infos_to_insert_membership_snapshots: - current_state_ids_map: MutableStateMap = dict( + current_state_ids_map: MutableStateMap[str] = dict( await self.main_store.get_partial_filtered_current_state_ids( room_id, state_filter=StateFilter.from_types( @@ -987,6 +982,7 @@ class EventsPersistenceStorageController: to_delete_membership_snapshots=user_ids_to_delete_membership_snapshots, ) + # TODO: Should we put this next to the other `_get_sliding_sync_*` functions? @classmethod def _get_sliding_sync_insert_values_from_state_map( cls, state_map: StateMap[EventBase] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 28ba64261e..f813d48519 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -34,7 +34,6 @@ from typing import ( Optional, Set, Tuple, - Union, cast, ) @@ -150,7 +149,8 @@ class SlidingSyncMembershipSnapshotSharedInsertValues( # TODO: tombstone_successor_room_id: Optional[str] -class SlidingSyncMembershipInfo(TypedDict, total=False): +@attr.s(slots=True, auto_attribs=True) +class SlidingSyncMembershipInfo: """ Values unique to each membership """ @@ -158,6 +158,9 @@ class SlidingSyncMembershipInfo(TypedDict, total=False): user_id: str sender: str membership_event_id: str + membership: str + # Sometimes we're working with events that aren't persisted yet + membership_event_stream_ordering: Optional[int] @attr.s(slots=True, auto_attribs=True) @@ -1507,8 +1510,7 @@ class PersistEventsStore: (room_id, user_id, membership_event_id, membership, event_stream_ordering {("," + ", ".join(insert_keys)) if insert_keys else ""}) VALUES ( - ?, ?, ?, - (SELECT membership FROM room_memberships WHERE event_id = ?), + ?, ?, ?, ?, (SELECT stream_ordering FROM events WHERE event_id = ?) {("," + ", ".join("?" for _ in insert_values)) if insert_values else ""} ) @@ -1522,10 +1524,10 @@ class PersistEventsStore: [ [ room_id, - membership_info["user_id"], - membership_info["membership_event_id"], - membership_info["membership_event_id"], - membership_info["membership_event_id"], + membership_info.user_id, + membership_info.membership_event_id, + membership_info.membership, + membership_info.membership_event_id, ] + list(insert_values) for membership_info in sliding_sync_table_changes.to_insert_membership_snapshots @@ -1549,6 +1551,8 @@ class PersistEventsStore: txn, {m for m in members_to_cache_bust if not self.hs.is_mine_id(m)} ) + # TODO: We can probably remove this function in favor of other stuff. + # TODO: This doesn't take into account redactions @classmethod def _get_relevant_sliding_sync_current_state_event_ids_txn( cls, txn: LoggingTransaction, room_id: str @@ -1587,10 +1591,12 @@ class PersistEventsStore: return current_state_map + # TODO: We can probably remove this function in favor of other stuff. + # TODO: Should we put this next to the other `_get_sliding_sync_*` function? @classmethod def _get_sliding_sync_insert_values_from_state_ids_map_txn( cls, txn: LoggingTransaction, state_map: StateMap[str] - ) -> Dict[str, Optional[Union[str, bool]]]: + ) -> SlidingSyncStateInsertValues: """ Fetch events in the `state_map` and extract the relevant state values needed to insert into the `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` @@ -1602,7 +1608,7 @@ class PersistEventsStore: the `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` tables. """ # Map of values to insert/update in the `sliding_sync_membership_snapshots` table - sliding_sync_insert_map: Dict[str, Optional[Union[str, bool]]] = {} + sliding_sync_insert_map: SlidingSyncStateInsertValues = {} # Fetch the raw event JSON from the database ( event_id_in_list_clause, @@ -1644,19 +1650,18 @@ class PersistEventsStore: sliding_sync_insert_map["room_name"] = room_name else: # We only expect to see events according to the - # `SLIDING_SYNC_RELEVANT_STATE_SET` which is what will - # `_get_relevant_sliding_sync_current_state_event_ids_txn()` will - # return. + # `SLIDING_SYNC_RELEVANT_STATE_SET`. raise AssertionError( f"Unexpected event (we should not be fetching extra events): ({event_type}, {state_key})" ) return sliding_sync_insert_map + # TODO: Should we put this next to the other `_get_sliding_sync_*` function? @classmethod def _get_sliding_sync_insert_values_from_stripped_state_txn( cls, txn: LoggingTransaction, unsigned_stripped_state_events: Any - ) -> Dict[str, Optional[Union[str, bool]]]: + ) -> SlidingSyncMembershipSnapshotSharedInsertValues: """ Pull out the relevant state values from the stripped state needed to insert into the `sliding_sync_membership_snapshots` tables. @@ -1666,7 +1671,7 @@ class PersistEventsStore: state values needed to insert into the `sliding_sync_membership_snapshots` tables. """ # Map of values to insert/update in the `sliding_sync_membership_snapshots` table - sliding_sync_insert_map: Dict[str, Optional[Union[str, bool]]] = {} + sliding_sync_insert_map: SlidingSyncMembershipSnapshotSharedInsertValues = {} if unsigned_stripped_state_events is not None: stripped_state_map: MutableStateMap[StrippedStateEvent] = {} diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e77ece682f..cf24d84554 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -81,7 +81,7 @@ from synapse.storage.util.id_generators import ( MultiWriterIdGenerator, ) from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import JsonDict, StrCollection, get_domain_from_id +from synapse.types import JsonDict, get_domain_from_id from synapse.types.state import StateFilter from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred, delay_cancellation @@ -511,6 +511,8 @@ class EventsWorkerStore(SQLBaseStore): ) -> Dict[str, EventBase]: """Get events from the database + Unknown events will be omitted from the response. + Args: event_ids: The event_ids of the events to fetch @@ -1979,34 +1981,6 @@ class EventsWorkerStore(SQLBaseStore): return int(res[0]), int(res[1]) - async def get_sender_for_event_ids( - self, event_ids: StrCollection - ) -> Mapping[str, str]: - """ - Get the sender for a list of event IDs. - - Args: - event_ids: The event IDs to look up. - - Returns: - A mapping from event ID to event sender. - """ - rows = cast( - List[Tuple[str, str]], - await self.db_pool.simple_select_many_batch( - table="events", - column="event_id", - iterable=event_ids, - retcols=( - "event_id", - "sender", - ), - desc="get_sender_for_event_ids", - ), - ) - - return dict(rows) - async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]: """Retrieve the entry with the lowest expiry timestamp in the event_expiry table, or None if there's no more event to expire. diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 526bf7ea62..3b55c528ce 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -35,6 +35,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.events import ( SLIDING_SYNC_RELEVANT_STATE_SET, PersistEventsStore, + SlidingSyncMembershipSnapshotSharedInsertValues, ) from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.types import JsonDict, MutableStateMap, StateMap, StrCollection @@ -614,6 +615,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): return 0 for (room_id,) in rooms_to_update_rows: + # TODO: Handle redactions current_state_map = PersistEventsStore._get_relevant_sliding_sync_current_state_event_ids_txn( txn, room_id ) @@ -741,9 +743,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): ) # Map of values to insert/update in the `sliding_sync_membership_snapshots` table - sliding_sync_membership_snapshots_insert_map: Dict[ - str, Optional[Union[str, bool]] - ] = {} + sliding_sync_membership_snapshots_insert_map: ( + SlidingSyncMembershipSnapshotSharedInsertValues + ) = {} if membership == Membership.JOIN: # If we're still joined, we can pull from current state current_state_map = PersistEventsStore._get_relevant_sliding_sync_current_state_event_ids_txn( @@ -754,9 +756,12 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): # for each room assert current_state_map - sliding_sync_membership_snapshots_insert_map = PersistEventsStore._get_sliding_sync_insert_values_from_state_ids_map_txn( + state_insert_values = PersistEventsStore._get_sliding_sync_insert_values_from_state_ids_map_txn( txn, current_state_map ) + sliding_sync_membership_snapshots_insert_map.update( + state_insert_values + ) # We should have some insert values for each room, even if they are `None` assert sliding_sync_membership_snapshots_insert_map @@ -854,9 +859,12 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): ) state_map = state_by_group[state_group] - sliding_sync_membership_snapshots_insert_map = PersistEventsStore._get_sliding_sync_insert_values_from_state_ids_map_txn( + state_insert_values = PersistEventsStore._get_sliding_sync_insert_values_from_state_ids_map_txn( txn, state_map ) + sliding_sync_membership_snapshots_insert_map.update( + state_insert_values + ) # We should have some insert values for each room, even if they are `None` assert sliding_sync_membership_snapshots_insert_map @@ -922,3 +930,299 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): ) return count + + # async def _sliding_sync_membership_snapshots_backfill( + # self, progress: JsonDict, batch_size: int + # ) -> int: + # """ + # Handles backfilling the `sliding_sync_membership_snapshots` table. + # """ + # last_event_stream_ordering = progress.get( + # "last_event_stream_ordering", -(1 << 31) + # ) + + # def _find_memberships_to_update_txn( + # txn: LoggingTransaction, + # ) -> List[Tuple[str, str, str, str, str, int, bool]]: + # # Fetch the set of event IDs that we want to update + # txn.execute( + # """ + # SELECT + # c.room_id, + # c.user_id, + # e.sender + # c.event_id, + # c.membership, + # c.event_stream_ordering, + # e.outlier + # FROM local_current_membership as c + # INNER JOIN events AS e USING (event_id) + # WHERE event_stream_ordering > ? + # ORDER BY event_stream_ordering ASC + # LIMIT ? + # """, + # (last_event_stream_ordering, batch_size), + # ) + + # memberships_to_update_rows = cast( + # List[Tuple[str, str, str, str, str, int, bool]], txn.fetchall() + # ) + + # return memberships_to_update_rows + + # memberships_to_update_rows = await self.db_pool.runInteraction( + # "sliding_sync_membership_snapshots_backfill._find_memberships_to_update_txn", + # _find_memberships_to_update_txn, + # ) + + # if not memberships_to_update_rows: + # await self.db_pool.updates._end_background_update( + # _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL + # ) + + # store = self.hs.get_storage_controllers().main + + # def _find_previous_membership_txn( + # txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int + # ) -> Tuple[str, str]: + # # Find the previous invite/knock event before the leave event + # txn.execute( + # """ + # SELECT event_id, membership + # FROM room_memberships + # WHERE + # room_id = ? + # AND user_id = ? + # AND event_stream_ordering < ? + # ORDER BY event_stream_ordering DESC + # LIMIT 1 + # """, + # ( + # room_id, + # user_id, + # stream_ordering, + # ), + # ) + # row = txn.fetchone() + + # # We should see a corresponding previous invite/knock event + # assert row is not None + # event_id, membership = row + + # return event_id, membership + + # # Map from (room_id, user_id) to ... + # to_insert_membership_snapshots: Dict[ + # Tuple[str, str], SlidingSyncMembershipSnapshotSharedInsertValues + # ] = {} + # to_insert_membership_infos: Dict[Tuple[str, str], SlidingSyncMembershipInfo] = ( + # {} + # ) + # for ( + # room_id, + # user_id, + # sender, + # membership_event_id, + # membership, + # membership_event_stream_ordering, + # is_outlier, + # ) in memberships_to_update_rows: + # # We don't know how to handle `membership` values other than these. The + # # code below would need to be updated. + # assert membership in ( + # Membership.JOIN, + # Membership.INVITE, + # Membership.KNOCK, + # Membership.LEAVE, + # Membership.BAN, + # ) + + # # Map of values to insert/update in the `sliding_sync_membership_snapshots` table + # sliding_sync_membership_snapshots_insert_map: ( + # SlidingSyncMembershipSnapshotSharedInsertValues + # ) = {} + # if membership == Membership.JOIN: + # # If we're still joined, we can pull from current state. + # current_state_ids_map: StateMap[str] = ( + # await store.get_partial_filtered_current_state_ids( + # room_id, + # state_filter=StateFilter.from_types( + # SLIDING_SYNC_RELEVANT_STATE_SET + # ), + # ) + # ) + # # We're iterating over rooms that we are joined to so they should + # # have `current_state_events` and we should have some current state + # # for each room + # assert current_state_ids_map + + # fetched_events = await store.get_events(current_state_ids_map.values()) + + # current_state_map: StateMap[EventBase] = { + # state_key: fetched_events[event_id] + # for state_key, event_id in current_state_ids_map.items() + # } + + # state_insert_values = EventsPersistenceStorageController._get_sliding_sync_insert_values_from_state_map( + # current_state_map + # ) + # sliding_sync_membership_snapshots_insert_map.update(state_insert_values) + # # We should have some insert values for each room, even if they are `None` + # assert sliding_sync_membership_snapshots_insert_map + + # # We have current state to work from + # sliding_sync_membership_snapshots_insert_map["has_known_state"] = True + # elif membership in (Membership.INVITE, Membership.KNOCK) or ( + # membership == Membership.LEAVE and is_outlier + # ): + # invite_or_knock_event_id = membership_event_id + # invite_or_knock_membership = membership + + # # If the event is an `out_of_band_membership` (special case of + # # `outlier`), we never had historical state so we have to pull from + # # the stripped state on the previous invite/knock event. This gives + # # us a consistent view of the room state regardless of your + # # membership (i.e. the room shouldn't disappear if your using the + # # `is_encrypted` filter and you leave). + # if membership == Membership.LEAVE and is_outlier: + # invite_or_knock_event_id, invite_or_knock_membership = ( + # await self.db_pool.runInteraction( + # "sliding_sync_membership_snapshots_backfill._find_previous_membership", + # _find_previous_membership_txn, + # room_id, + # user_id, + # membership_event_stream_ordering, + # ) + # ) + + # # Pull from the stripped state on the invite/knock event + # invite_or_knock_event = await store.get_event(invite_or_knock_event_id) + + # raw_stripped_state_events = None + # if invite_or_knock_membership == Membership.INVITE: + # invite_room_state = invite_or_knock_event.unsigned.get( + # "invite_room_state" + # ) + # raw_stripped_state_events = invite_room_state + # elif invite_or_knock_membership == Membership.KNOCK: + # knock_room_state = invite_or_knock_event.unsigned.get( + # "knock_room_state" + # ) + # raw_stripped_state_events = knock_room_state + + # sliding_sync_membership_snapshots_insert_map = await self.db_pool.runInteraction( + # "sliding_sync_membership_snapshots_backfill._get_sliding_sync_insert_values_from_stripped_state_txn", + # PersistEventsStore._get_sliding_sync_insert_values_from_stripped_state_txn, + # raw_stripped_state_events, + # ) + + # # We should have some insert values for each room, even if no + # # stripped state is on the event because we still want to record + # # that we have no known state + # assert sliding_sync_membership_snapshots_insert_map + # elif membership in (Membership.LEAVE, Membership.BAN): + # # Pull from historical state + # state_group = await store._get_state_group_for_event( + # membership_event_id + # ) + # # We should know the state for the event + # assert state_group is not None + + # state_by_group = await self.db_pool.runInteraction( + # "sliding_sync_membership_snapshots_backfill._get_state_groups_from_groups_txn", + # self._get_state_groups_from_groups_txn, + # groups=[state_group], + # state_filter=StateFilter.from_types( + # SLIDING_SYNC_RELEVANT_STATE_SET + # ), + # ) + # state_ids_map = state_by_group[state_group] + + # fetched_events = await store.get_events(state_ids_map.values()) + + # state_map: StateMap[EventBase] = { + # state_key: fetched_events[event_id] + # for state_key, event_id in state_ids_map.items() + # } + + # state_insert_values = EventsPersistenceStorageController._get_sliding_sync_insert_values_from_state_map( + # state_map + # ) + # sliding_sync_membership_snapshots_insert_map.update(state_insert_values) + # # We should have some insert values for each room, even if they are `None` + # assert sliding_sync_membership_snapshots_insert_map + + # # We have historical state to work from + # sliding_sync_membership_snapshots_insert_map["has_known_state"] = True + # else: + # assert_never(membership) + + # to_insert_membership_snapshots[(room_id, user_id)] = ( + # sliding_sync_membership_snapshots_insert_map + # ) + # to_insert_membership_infos[(room_id, user_id)] = SlidingSyncMembershipInfo( + # user_id=user_id, + # sender=sender, + # membership_event_id=membership_event_id, + # membership=membership, + # membership_event_stream_ordering=membership_event_stream_ordering, + # ) + + # def _backfill_table_txn(txn: LoggingTransaction) -> None: + # for key, insert_map in to_insert_membership_snapshots.items(): + # room_id, user_id = key + # membership_info = to_insert_membership_infos[key] + # membership_event_id = membership_info.membership_event_id + # membership = membership_info.membership + # membership_event_stream_ordering = ( + # membership_info.membership_event_stream_ordering + # ) + + # # Pulling keys/values separately is safe and will produce congruent + # # lists + # insert_keys = insert_map.keys() + # insert_values = insert_map.values() + # # We don't need to do anything `ON CONFLICT` because we never partially + # # insert/update the snapshots + # txn.execute( + # f""" + # INSERT INTO sliding_sync_membership_snapshots + # (room_id, user_id, membership_event_id, membership, event_stream_ordering + # {("," + ", ".join(insert_keys)) if insert_keys else ""}) + # VALUES ( + # ?, ?, ?, ?, ?, + # {("," + ", ".join("?" for _ in insert_values)) if insert_values else ""} + # ) + # ON CONFLICT (room_id, user_id) + # DO NOTHING + # """, + # [ + # room_id, + # user_id, + # membership_event_id, + # membership, + # membership_event_stream_ordering, + # ] + # + list(insert_values), + # ) + + # await self.db_pool.runInteraction( + # "sliding_sync_membership_snapshots_backfill", _backfill_table_txn + # ) + + # # Update the progress + # ( + # _room_id, + # _user_id, + # _sender, + # _membership_event_id, + # _membership, + # membership_event_stream_ordering, + # _is_outlier, + # ) = memberships_to_update_rows[-1] + # await self.db_pool.updates._background_update_progress( + # _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL, + # {"last_event_stream_ordering": membership_event_stream_ordering}, + # ) + + # return len(memberships_to_update_rows)