From ec174d047005e4ac976311f4d3730452b2c5710f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 2 Apr 2024 15:33:56 +0100 Subject: [PATCH 1/7] Refactor chain fetching (#17044) Since these queries are duplicated in two places. --- changelog.d/17044.misc | 1 + .../databases/main/event_federation.py | 162 +++++++----------- 2 files changed, 67 insertions(+), 96 deletions(-) create mode 100644 changelog.d/17044.misc diff --git a/changelog.d/17044.misc b/changelog.d/17044.misc new file mode 100644 index 0000000000..a1439752d3 --- /dev/null +++ b/changelog.d/17044.misc @@ -0,0 +1 @@ +Refactor auth chain fetching to reduce duplication. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 846c3f363a..fb132ef090 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -27,6 +27,7 @@ from typing import ( Collection, Dict, FrozenSet, + Generator, Iterable, List, Optional, @@ -279,64 +280,16 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Now we look up all links for the chains we have, adding chains that # are reachable from any event. - # - # This query is structured to first get all chain IDs reachable, and - # then pull out all links from those chains. This does pull out more - # rows than is strictly necessary, however there isn't a way of - # structuring the recursive part of query to pull out the links without - # also returning large quantities of redundant data (which can make it a - # lot slower). - sql = """ - WITH RECURSIVE links(chain_id) AS ( - SELECT - DISTINCT origin_chain_id - FROM event_auth_chain_links WHERE %s - UNION - SELECT - target_chain_id - FROM event_auth_chain_links - INNER JOIN links ON (chain_id = origin_chain_id) - ) - SELECT - origin_chain_id, origin_sequence_number, - target_chain_id, target_sequence_number - FROM links - INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id) - """ # A map from chain ID to max sequence number *reachable* from any event ID. chains: Dict[int, int] = {} - - # Add all linked chains reachable from initial set of chains. - chains_to_fetch = set(event_chains.keys()) - while chains_to_fetch: - batch2 = tuple(itertools.islice(chains_to_fetch, 1000)) - chains_to_fetch.difference_update(batch2) - clause, args = make_in_list_sql_clause( - txn.database_engine, "origin_chain_id", batch2 - ) - txn.execute(sql % (clause,), args) - - links: Dict[int, List[Tuple[int, int, int]]] = {} - - for ( - origin_chain_id, - origin_sequence_number, - target_chain_id, - target_sequence_number, - ) in txn: - links.setdefault(origin_chain_id, []).append( - (origin_sequence_number, target_chain_id, target_sequence_number) - ) - + for links in self._get_chain_links(txn, set(event_chains.keys())): for chain_id in links: if chain_id not in event_chains: continue _materialize(chain_id, event_chains[chain_id], links, chains) - chains_to_fetch.difference_update(chains) - # Add the initial set of chains, excluding the sequence corresponding to # initial event. for chain_id, seq_no in event_chains.items(): @@ -380,6 +333,68 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas return results + @classmethod + def _get_chain_links( + cls, txn: LoggingTransaction, chains_to_fetch: Set[int] + ) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]: + """Fetch all auth chain links from the given set of chains, and all + links from those chains, recursively. + + Note: This may return links that are not reachable from the given + chains. + + Returns a generator that produces dicts from origin chain ID to 3-tuple + of origin sequence number, target chain ID and target sequence number. + """ + + # This query is structured to first get all chain IDs reachable, and + # then pull out all links from those chains. This does pull out more + # rows than is strictly necessary, however there isn't a way of + # structuring the recursive part of query to pull out the links without + # also returning large quantities of redundant data (which can make it a + # lot slower). + sql = """ + WITH RECURSIVE links(chain_id) AS ( + SELECT + DISTINCT origin_chain_id + FROM event_auth_chain_links WHERE %s + UNION + SELECT + target_chain_id + FROM event_auth_chain_links + INNER JOIN links ON (chain_id = origin_chain_id) + ) + SELECT + origin_chain_id, origin_sequence_number, + target_chain_id, target_sequence_number + FROM links + INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id) + """ + + while chains_to_fetch: + batch2 = tuple(itertools.islice(chains_to_fetch, 1000)) + chains_to_fetch.difference_update(batch2) + clause, args = make_in_list_sql_clause( + txn.database_engine, "origin_chain_id", batch2 + ) + txn.execute(sql % (clause,), args) + + links: Dict[int, List[Tuple[int, int, int]]] = {} + + for ( + origin_chain_id, + origin_sequence_number, + target_chain_id, + target_sequence_number, + ) in txn: + links.setdefault(origin_chain_id, []).append( + (origin_sequence_number, target_chain_id, target_sequence_number) + ) + + chains_to_fetch.difference_update(links) + + yield links + def _get_auth_chain_ids_txn( self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool ) -> Set[str]: @@ -564,53 +579,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Now we look up all links for the chains we have, adding chains that # are reachable from any event. - # - # This query is structured to first get all chain IDs reachable, and - # then pull out all links from those chains. This does pull out more - # rows than is strictly necessary, however there isn't a way of - # structuring the recursive part of query to pull out the links without - # also returning large quantities of redundant data (which can make it a - # lot slower). - sql = """ - WITH RECURSIVE links(chain_id) AS ( - SELECT - DISTINCT origin_chain_id - FROM event_auth_chain_links WHERE %s - UNION - SELECT - target_chain_id - FROM event_auth_chain_links - INNER JOIN links ON (chain_id = origin_chain_id) - ) - SELECT - origin_chain_id, origin_sequence_number, - target_chain_id, target_sequence_number - FROM links - INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id) - """ - - # (We need to take a copy of `seen_chains` as we want to mutate it in - # the loop) - chains_to_fetch = set(seen_chains) - while chains_to_fetch: - batch2 = tuple(itertools.islice(chains_to_fetch, 1000)) - clause, args = make_in_list_sql_clause( - txn.database_engine, "origin_chain_id", batch2 - ) - txn.execute(sql % (clause,), args) - - links: Dict[int, List[Tuple[int, int, int]]] = {} - - for ( - origin_chain_id, - origin_sequence_number, - target_chain_id, - target_sequence_number, - ) in txn: - links.setdefault(origin_chain_id, []).append( - (origin_sequence_number, target_chain_id, target_sequence_number) - ) + # (We need to take a copy of `seen_chains` as the function mutates it) + for links in self._get_chain_links(txn, set(seen_chains)): for chains in set_to_chain: for chain_id in links: if chain_id not in chains: @@ -618,7 +589,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas _materialize(chain_id, chains[chain_id], links, chains) - chains_to_fetch.difference_update(chains) seen_chains.update(chains) # Now for each chain we figure out the maximum sequence number reachable From ca27b516656223150d218bdd838df302fedf838c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 2 Apr 2024 17:17:02 +0100 Subject: [PATCH 2/7] 1.104.0 --- CHANGES.md | 7 +++++++ changelog.d/17031.feature | 1 - debian/changelog | 6 ++++++ pyproject.toml | 2 +- 4 files changed, 14 insertions(+), 2 deletions(-) delete mode 100644 changelog.d/17031.feature diff --git a/CHANGES.md b/CHANGES.md index fa9af218a6..168e29f1b2 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,10 @@ +# Synapse 1.104.0 (2024-04-02) + +### Bugfixes + +- Fix regression when using OIDC provider. Introduced in v1.104.0rc1. ([\#17031](https://github.com/element-hq/synapse/issues/17031)) + + # Synapse 1.104.0rc1 (2024-03-26) ### Features diff --git a/changelog.d/17031.feature b/changelog.d/17031.feature deleted file mode 100644 index 0f28cbbcd6..0000000000 --- a/changelog.d/17031.feature +++ /dev/null @@ -1 +0,0 @@ -OIDC: try to JWT decode userinfo response if JSON parsing failed. diff --git a/debian/changelog b/debian/changelog index b915b6e2cb..28451044ab 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.104.0) stable; urgency=medium + + * New Synapse release 1.104.0. + + -- Synapse Packaging team Tue, 02 Apr 2024 17:15:45 +0100 + matrix-synapse-py3 (1.104.0~rc1) stable; urgency=medium * New Synapse release 1.104.0rc1. diff --git a/pyproject.toml b/pyproject.toml index 8369139301..9a645079c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,7 +96,7 @@ module-name = "synapse.synapse_rust" [tool.poetry] name = "matrix-synapse" -version = "1.104.0rc1" +version = "1.104.0" description = "Homeserver for the Matrix decentralised comms protocol" authors = ["Matrix.org Team and Contributors "] license = "AGPL-3.0-or-later" From 31122b71bcf29b4a034be4fc14770f4b8a45b2c5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 4 Apr 2024 11:05:40 +0100 Subject: [PATCH 3/7] Add missing index to `access_tokens` table (#17045) This was causing sequential scans when using refresh tokens. --- changelog.d/17045.misc | 1 + synapse/storage/databases/main/registration.py | 7 +++++++ 2 files changed, 8 insertions(+) create mode 100644 changelog.d/17045.misc diff --git a/changelog.d/17045.misc b/changelog.d/17045.misc new file mode 100644 index 0000000000..0d042a43ff --- /dev/null +++ b/changelog.d/17045.misc @@ -0,0 +1 @@ +Improve database performance by adding a missing index to `access_tokens.refresh_token_id`. diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index d939ade427..30a3ae3055 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -2266,6 +2266,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): ): super().__init__(database, db_conn, hs) + self.db_pool.updates.register_background_index_update( + update_name="access_tokens_refresh_token_id_idx", + index_name="access_tokens_refresh_token_id_idx", + table="access_tokens", + columns=("refresh_token_id",), + ) + self._ignore_unknown_session_error = ( hs.config.server.request_token_inhibit_3pid_errors ) From 05957ac70f5d634eafbea61bd79a9a89196507c2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 4 Apr 2024 12:47:59 +0100 Subject: [PATCH 4/7] Fix bug in `/sync` response for archived rooms (#16932) This PR fixes a very, very niche edge-case, but I've got some more work coming which will otherwise make the problem worse. The bug happens when the syncing user leaves a room, and has a sync filter which includes "left" rooms, but sets the timeline limit to 0. In that case, the state returned in the `state` section is calculated incorrectly. The fix is to pass a token corresponding to the point that the user leaves the room through to `compute_state_delta`. --- changelog.d/16932.bugfix | 1 + synapse/handlers/sync.py | 121 ++++++++++++++++++--- tests/handlers/test_sync.py | 208 +++++++++++++++++++++++++++++++++--- tests/rest/client/utils.py | 18 ++-- 4 files changed, 314 insertions(+), 34 deletions(-) create mode 100644 changelog.d/16932.bugfix diff --git a/changelog.d/16932.bugfix b/changelog.d/16932.bugfix new file mode 100644 index 0000000000..624388ea8e --- /dev/null +++ b/changelog.d/16932.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug which could cause incorrect state to be returned from `/sync` for rooms where the user has left. \ No newline at end of file diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 3aa2e2b7ba..7fcd54ac55 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -953,7 +953,7 @@ class SyncHandler: batch: TimelineBatch, sync_config: SyncConfig, since_token: Optional[StreamToken], - now_token: StreamToken, + end_token: StreamToken, full_state: bool, ) -> MutableStateMap[EventBase]: """Works out the difference in state between the end of the previous sync and @@ -964,7 +964,9 @@ class SyncHandler: batch: The timeline batch for the room that will be sent to the user. sync_config: since_token: Token of the end of the previous batch. May be `None`. - now_token: Token of the end of the current batch. + end_token: Token of the end of the current batch. Normally this will be + the same as the global "now_token", but if the user has left the room, + the point just after their leave event. full_state: Whether to force returning the full state. `lazy_load_members` still applies when `full_state` is `True`. @@ -1044,7 +1046,7 @@ class SyncHandler: room_id, sync_config.user, batch, - now_token, + end_token, members_to_fetch, timeline_state, ) @@ -1058,7 +1060,7 @@ class SyncHandler: room_id, batch, since_token, - now_token, + end_token, members_to_fetch, timeline_state, ) @@ -1130,7 +1132,7 @@ class SyncHandler: room_id: str, syncing_user: UserID, batch: TimelineBatch, - now_token: StreamToken, + end_token: StreamToken, members_to_fetch: Optional[Set[str]], timeline_state: StateMap[str], ) -> StateMap[str]: @@ -1143,7 +1145,9 @@ class SyncHandler: room_id: The room we are calculating for. syncing_user: The user that is calling `/sync`. batch: The timeline batch for the room that will be sent to the user. - now_token: Token of the end of the current batch. + end_token: Token of the end of the current batch. Normally this will be + the same as the global "now_token", but if the user has left the room, + the point just after their leave event. members_to_fetch: If lazy-loading is enabled, the memberships needed for events in the timeline. timeline_state: The contribution to the room state from state events in @@ -1202,7 +1206,7 @@ class SyncHandler: else: state_at_timeline_end = await self.get_state_at( room_id, - stream_position=now_token, + stream_position=end_token, state_filter=state_filter, await_full_state=await_full_state, ) @@ -1223,7 +1227,7 @@ class SyncHandler: room_id: str, batch: TimelineBatch, since_token: StreamToken, - now_token: StreamToken, + end_token: StreamToken, members_to_fetch: Optional[Set[str]], timeline_state: StateMap[str], ) -> StateMap[str]: @@ -1239,7 +1243,9 @@ class SyncHandler: room_id: The room we are calculating for. batch: The timeline batch for the room that will be sent to the user. since_token: Token of the end of the previous batch. - now_token: Token of the end of the current batch. + end_token: Token of the end of the current batch. Normally this will be + the same as the global "now_token", but if the user has left the room, + the point just after their leave event. members_to_fetch: If lazy-loading is enabled, the memberships needed for events in the timeline. Otherwise, `None`. timeline_state: The contribution to the room state from state events in @@ -1273,7 +1279,7 @@ class SyncHandler: # the recent events. state_at_timeline_start = await self.get_state_at( room_id, - stream_position=now_token, + stream_position=end_token, state_filter=state_filter, await_full_state=await_full_state, ) @@ -1312,7 +1318,7 @@ class SyncHandler: # the recent events. state_at_timeline_end = await self.get_state_at( room_id, - stream_position=now_token, + stream_position=end_token, state_filter=state_filter, await_full_state=await_full_state, ) @@ -2344,6 +2350,7 @@ class SyncHandler: full_state=False, since_token=since_token, upto_token=leave_token, + end_token=leave_token, out_of_band=leave_event.internal_metadata.is_out_of_band_membership(), ) ) @@ -2381,6 +2388,7 @@ class SyncHandler: full_state=False, since_token=None if newly_joined else since_token, upto_token=prev_batch_token, + end_token=now_token, ) else: entry = RoomSyncResultBuilder( @@ -2391,6 +2399,7 @@ class SyncHandler: full_state=False, since_token=since_token, upto_token=since_token, + end_token=now_token, ) room_entries.append(entry) @@ -2449,6 +2458,7 @@ class SyncHandler: full_state=True, since_token=since_token, upto_token=now_token, + end_token=now_token, ) ) elif event.membership == Membership.INVITE: @@ -2478,6 +2488,7 @@ class SyncHandler: full_state=True, since_token=since_token, upto_token=leave_token, + end_token=leave_token, ) ) @@ -2548,6 +2559,7 @@ class SyncHandler: { "since_token": since_token, "upto_token": upto_token, + "end_token": room_builder.end_token, } ) @@ -2621,7 +2633,7 @@ class SyncHandler: batch, sync_config, since_token, - now_token, + room_builder.end_token, full_state=full_state, ) else: @@ -2781,6 +2793,70 @@ def _calculate_state( e for t, e in timeline_start.items() if t[0] == EventTypes.Member ) + # Naively, we would just return the difference between the state at the start + # of the timeline (`timeline_start_ids`) and that at the end of the previous sync + # (`previous_timeline_end_ids`). However, that fails in the presence of forks in + # the DAG. + # + # For example, consider a DAG such as the following: + # + # E1 + # ↗ ↖ + # | S2 + # | ↑ + # --|------|---- + # | | + # E3 | + # ↖ / + # E4 + # + # ... and a filter that means we only return 2 events, represented by the dashed + # horizontal line. Assuming S2 was *not* included in the previous sync, we need to + # include it in the `state` section. + # + # Note that the state at the start of the timeline (E3) does not include S2. So, + # to make sure it gets included in the calculation here, we actually look at + # the state at the *end* of the timeline, and subtract any events that are present + # in the timeline. + # + # ---------- + # + # Aside 1: You may then wonder if we need to include `timeline_start` in the + # calculation. Consider a linear DAG: + # + # E1 + # ↑ + # S2 + # ↑ + # ----|------ + # | + # E3 + # ↑ + # S4 + # ↑ + # E5 + # + # ... where S2 and S4 change the same piece of state; and where we have a filter + # that returns 3 events (E3, S4, E5). We still need to tell the client about S2, + # because it might affect the display of E3. However, the state at the end of the + # timeline only tells us about S4; if we don't inspect `timeline_start` we won't + # find out about S2. + # + # (There are yet more complicated cases in which a state event is excluded from the + # timeline, but whose effect actually lands in the DAG in the *middle* of the + # timeline. We have no way to represent that in the /sync response, and we don't + # even try; it is ether omitted or plonked into `state` as if it were at the start + # of the timeline, depending on what else is in the timeline.) + # + # ---------- + # + # Aside 2: it's worth noting that `timeline_end`, as provided to us, is actually + # the state *before* the final event in the timeline. In other words: if the final + # event in the timeline is a state event, it won't be included in `timeline_end`. + # However, that doesn't matter here, because the only difference can be in that + # one piece of state, and by definition that event is in the timeline, so we + # don't need to include it in the `state` section. + state_ids = ( (timeline_end_ids | timeline_start_ids) - previous_timeline_end_ids @@ -2883,13 +2959,30 @@ class RoomSyncResultBuilder: Attributes: room_id + rtype: One of `"joined"` or `"archived"` + events: List of events to include in the room (more events may be added when generating result). + newly_joined: If the user has newly joined the room + full_state: Whether the full state should be sent in result + since_token: Earliest point to return events from, or None - upto_token: Latest point to return events from. + + upto_token: Latest point to return events from. If `events` is populated, + this is set to the token at the start of `events` + + end_token: The last point in the timeline that the client should see events + from. Normally this will be the same as the global `now_token`, but in + the case of rooms where the user has left the room, this will be the point + just after their leave event. + + This is used in the calculation of the state which is returned in `state`: + any state changes *up to* `end_token` (and not beyond!) which are not + reflected in the timeline need to be returned in `state`. + out_of_band: whether the events in the room are "out of band" events and the server isn't in the room. """ @@ -2901,5 +2994,5 @@ class RoomSyncResultBuilder: full_state: bool since_token: Optional[StreamToken] upto_token: StreamToken - + end_token: StreamToken out_of_band: bool = False diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 1b36324b8f..897c52c785 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -17,14 +17,16 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Collection, List, Optional +from typing import Collection, ContextManager, List, Optional from unittest.mock import AsyncMock, Mock, patch +from parameterized import parameterized + from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, ResourceLimitError -from synapse.api.filtering import Filtering +from synapse.api.filtering import FilterCollection, Filtering from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.events import EventBase from synapse.events.snapshot import EventContext @@ -33,7 +35,7 @@ from synapse.handlers.sync import SyncConfig, SyncResult from synapse.rest import admin from synapse.rest.client import knock, login, room from synapse.server import HomeServer -from synapse.types import UserID, create_requester +from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock import tests.unittest @@ -258,13 +260,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Eve tries to join the room. We monkey patch the internal logic which selects # the prev_events used when creating the join event, such that the ban does not # precede the join. - mocked_get_prev_events = patch.object( - self.hs.get_datastores().main, - "get_prev_events_for_room", - new_callable=AsyncMock, - return_value=[last_room_creation_event_id], - ) - with mocked_get_prev_events: + with self._patch_get_latest_events([last_room_creation_event_id]): self.helper.join(room_id, eve, tok=eve_token) # Eve makes a second, incremental sync. @@ -288,6 +284,180 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ) self.assertEqual(eve_initial_sync_after_join.joined, []) + def test_state_includes_changes_on_forks(self) -> None: + """State changes that happen on a fork of the DAG must be included in `state` + + Given the following DAG: + + E1 + ↗ ↖ + | S2 + | ↑ + --|------|---- + | | + E3 | + ↖ / + E4 + + ... and a filter that means we only return 2 events, represented by the dashed + horizontal line: `S2` must be included in the `state` section. + """ + alice = self.register_user("alice", "password") + alice_tok = self.login(alice, "password") + alice_requester = create_requester(alice) + room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok) + + # Do an initial sync as Alice to get a known starting point. + initial_sync_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + alice_requester, generate_sync_config(alice) + ) + ) + last_room_creation_event_id = ( + initial_sync_result.joined[0].timeline.events[-1].event_id + ) + + # Send a state event, and a regular event, both using the same prev ID + with self._patch_get_latest_events([last_room_creation_event_id]): + s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[ + "event_id" + ] + e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"] + + # Send a final event, joining the two branches of the dag + e4_event = self.helper.send(room_id, "e4", tok=alice_tok)["event_id"] + + # do an incremental sync, with a filter that will ensure we only get two of + # the three new events. + incremental_sync = self.get_success( + self.sync_handler.wait_for_sync_for_user( + alice_requester, + generate_sync_config( + alice, + filter_collection=FilterCollection( + self.hs, {"room": {"timeline": {"limit": 2}}} + ), + ), + since_token=initial_sync_result.next_batch, + ) + ) + + # The state event should appear in the 'state' section of the response. + room_sync = incremental_sync.joined[0] + self.assertEqual(room_sync.room_id, room_id) + self.assertTrue(room_sync.timeline.limited) + self.assertEqual( + [e.event_id for e in room_sync.timeline.events], + [e3_event, e4_event], + ) + self.assertEqual( + [e.event_id for e in room_sync.state.values()], + [s2_event], + ) + + @parameterized.expand( + [ + (False, False), + (True, False), + (False, True), + (True, True), + ] + ) + def test_archived_rooms_do_not_include_state_after_leave( + self, initial_sync: bool, empty_timeline: bool + ) -> None: + """If the user leaves the room, state changes that happen after they leave are not returned. + + We try with both a zero and a normal timeline limit, + and we try both an initial sync and an incremental sync for both. + """ + if empty_timeline and not initial_sync: + # FIXME synapse doesn't return the room at all in this situation! + self.skipTest("Synapse does not correctly handle this case") + + # Alice creates the room, and bob joins. + alice = self.register_user("alice", "password") + alice_tok = self.login(alice, "password") + + bob = self.register_user("bob", "password") + bob_tok = self.login(bob, "password") + bob_requester = create_requester(bob) + + room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok) + self.helper.join(room_id, bob, tok=bob_tok) + + initial_sync_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + bob_requester, generate_sync_config(bob) + ) + ) + + # Alice sends a message and a state + before_message_event = self.helper.send(room_id, "before", tok=alice_tok)[ + "event_id" + ] + before_state_event = self.helper.send_state( + room_id, "test_state", {"body": "before"}, tok=alice_tok + )["event_id"] + + # Bob leaves + leave_event = self.helper.leave(room_id, bob, tok=bob_tok)["event_id"] + + # Alice sends some more stuff + self.helper.send(room_id, "after", tok=alice_tok)["event_id"] + self.helper.send_state(room_id, "test_state", {"body": "after"}, tok=alice_tok)[ + "event_id" + ] + + # And now, Bob resyncs. + filter_dict: JsonDict = {"room": {"include_leave": True}} + if empty_timeline: + filter_dict["room"]["timeline"] = {"limit": 0} + sync_room_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + bob_requester, + generate_sync_config( + bob, filter_collection=FilterCollection(self.hs, filter_dict) + ), + since_token=None if initial_sync else initial_sync_result.next_batch, + ) + ).archived[0] + + if empty_timeline: + # The timeline should be empty + self.assertEqual(sync_room_result.timeline.events, []) + + # And the state should include the leave event... + self.assertEqual( + sync_room_result.state[("m.room.member", bob)].event_id, leave_event + ) + # ... and the state change before he left. + self.assertEqual( + sync_room_result.state[("test_state", "")].event_id, before_state_event + ) + else: + # The last three events in the timeline should be those leading up to the + # leave + self.assertEqual( + [e.event_id for e in sync_room_result.timeline.events[-3:]], + [before_message_event, before_state_event, leave_event], + ) + # ... And the state should be empty + self.assertEqual(sync_room_result.state, {}) + + def _patch_get_latest_events(self, latest_events: List[str]) -> ContextManager: + """Monkey-patch `get_prev_events_for_room` + + Returns a context manager which will replace the implementation of + `get_prev_events_for_room` with one which returns `latest_events`. + """ + return patch.object( + self.hs.get_datastores().main, + "get_prev_events_for_room", + new_callable=AsyncMock, + return_value=latest_events, + ) + def test_call_invite_in_public_room_not_returned(self) -> None: user = self.register_user("alice", "password") tok = self.login(user, "password") @@ -401,14 +571,26 @@ _request_key = 0 def generate_sync_config( - user_id: str, device_id: Optional[str] = "device_id" + user_id: str, + device_id: Optional[str] = "device_id", + filter_collection: Optional[FilterCollection] = None, ) -> SyncConfig: - """Generate a sync config (with a unique request key).""" + """Generate a sync config (with a unique request key). + + Args: + user_id: user who is syncing. + device_id: device that is syncing. Defaults to "device_id". + filter_collection: filter to apply. Defaults to the default filter (ie, + return everything, with a default limit) + """ + if filter_collection is None: + filter_collection = Filtering(Mock()).DEFAULT_FILTER_COLLECTION + global _request_key _request_key += 1 return SyncConfig( user=UserID.from_string(user_id), - filter_collection=Filtering(Mock()).DEFAULT_FILTER_COLLECTION, + filter_collection=filter_collection, is_guest=False, request_key=("request_key", _request_key), device_id=device_id, diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index daa68d78b9..fe00afe198 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -170,8 +170,8 @@ class RestHelper: targ: Optional[str] = None, expect_code: int = HTTPStatus.OK, tok: Optional[str] = None, - ) -> None: - self.change_membership( + ) -> JsonDict: + return self.change_membership( room=room, src=src, targ=targ, @@ -189,8 +189,8 @@ class RestHelper: appservice_user_id: Optional[str] = None, expect_errcode: Optional[Codes] = None, expect_additional_fields: Optional[dict] = None, - ) -> None: - self.change_membership( + ) -> JsonDict: + return self.change_membership( room=room, src=user, targ=user, @@ -242,8 +242,8 @@ class RestHelper: user: Optional[str] = None, expect_code: int = HTTPStatus.OK, tok: Optional[str] = None, - ) -> None: - self.change_membership( + ) -> JsonDict: + return self.change_membership( room=room, src=user, targ=user, @@ -282,7 +282,7 @@ class RestHelper: expect_code: int = HTTPStatus.OK, expect_errcode: Optional[str] = None, expect_additional_fields: Optional[dict] = None, - ) -> None: + ) -> JsonDict: """ Send a membership state event into a room. @@ -298,6 +298,9 @@ class RestHelper: using an application service access token in `tok`. expect_code: The expected HTTP response code expect_errcode: The expected Matrix error code + + Returns: + The JSON response """ temp_id = self.auth_user_id self.auth_user_id = src @@ -356,6 +359,7 @@ class RestHelper: ) self.auth_user_id = temp_id + return channel.json_body def send( self, From 230b709d9d8b09fd4884be1265db535263975e35 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 4 Apr 2024 13:14:24 +0100 Subject: [PATCH 5/7] `/sync`: fix bug in calculating `state` response (#16930) Fix a long-standing issue which could cause state to be omitted from the sync response if the last event was filtered out. Fixes: https://github.com/element-hq/synapse/issues/16928 --- changelog.d/16930.bugfix | 1 + synapse/handlers/sync.py | 54 ++++++------------------- tests/handlers/test_sync.py | 80 +++++++++++++++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 41 deletions(-) create mode 100644 changelog.d/16930.bugfix diff --git a/changelog.d/16930.bugfix b/changelog.d/16930.bugfix new file mode 100644 index 0000000000..21f964ef97 --- /dev/null +++ b/changelog.d/16930.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug which could cause state to be omitted from `/sync` responses when certain events are filtered out of the timeline. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 7fcd54ac55..773e291aa8 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1187,15 +1187,14 @@ class SyncHandler: await_full_state = True lazy_load_members = False - if batch: - state_at_timeline_end = ( - await self._state_storage_controller.get_state_ids_for_event( - batch.events[-1].event_id, - state_filter=state_filter, - await_full_state=await_full_state, - ) - ) + state_at_timeline_end = await self.get_state_at( + room_id, + stream_position=end_token, + state_filter=state_filter, + await_full_state=await_full_state, + ) + if batch: state_at_timeline_start = ( await self._state_storage_controller.get_state_ids_for_event( batch.events[0].event_id, @@ -1204,13 +1203,6 @@ class SyncHandler: ) ) else: - state_at_timeline_end = await self.get_state_at( - room_id, - stream_position=end_token, - state_filter=state_filter, - await_full_state=await_full_state, - ) - state_at_timeline_start = state_at_timeline_end state_ids = _calculate_state( @@ -1305,23 +1297,12 @@ class SyncHandler: await_full_state=await_full_state, ) - if batch: - state_at_timeline_end = ( - await self._state_storage_controller.get_state_ids_for_event( - batch.events[-1].event_id, - state_filter=state_filter, - await_full_state=await_full_state, - ) - ) - else: - # We can get here if the user has ignored the senders of all - # the recent events. - state_at_timeline_end = await self.get_state_at( - room_id, - stream_position=end_token, - state_filter=state_filter, - await_full_state=await_full_state, - ) + state_at_timeline_end = await self.get_state_at( + room_id, + stream_position=end_token, + state_filter=state_filter, + await_full_state=await_full_state, + ) state_ids = _calculate_state( timeline_contains=timeline_state, @@ -2847,15 +2828,6 @@ def _calculate_state( # timeline. We have no way to represent that in the /sync response, and we don't # even try; it is ether omitted or plonked into `state` as if it were at the start # of the timeline, depending on what else is in the timeline.) - # - # ---------- - # - # Aside 2: it's worth noting that `timeline_end`, as provided to us, is actually - # the state *before* the final event in the timeline. In other words: if the final - # event in the timeline is a state event, it won't be included in `timeline_end`. - # However, that doesn't matter here, because the only difference can be in that - # one piece of state, and by definition that event is in the timeline, so we - # don't need to include it in the `state` section. state_ids = ( (timeline_end_ids | timeline_start_ids) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 897c52c785..5d8e886541 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -355,6 +355,86 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): [s2_event], ) + def test_state_includes_changes_on_forks_when_events_excluded(self) -> None: + """A variation on the previous test, but where one event is filtered + + The DAG is the same as the previous test, but E4 is excluded by the filter. + + E1 + ↗ ↖ + | S2 + | ↑ + --|------|---- + | | + E3 | + ↖ / + (E4) + + """ + + alice = self.register_user("alice", "password") + alice_tok = self.login(alice, "password") + alice_requester = create_requester(alice) + room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok) + + # Do an initial sync as Alice to get a known starting point. + initial_sync_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + alice_requester, generate_sync_config(alice) + ) + ) + last_room_creation_event_id = ( + initial_sync_result.joined[0].timeline.events[-1].event_id + ) + + # Send a state event, and a regular event, both using the same prev ID + with self._patch_get_latest_events([last_room_creation_event_id]): + s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[ + "event_id" + ] + e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"] + + # Send a final event, joining the two branches of the dag + self.helper.send(room_id, "e4", type="not_a_normal_message", tok=alice_tok)[ + "event_id" + ] + + # do an incremental sync, with a filter that will only return E3, excluding S2 + # and E4. + incremental_sync = self.get_success( + self.sync_handler.wait_for_sync_for_user( + alice_requester, + generate_sync_config( + alice, + filter_collection=FilterCollection( + self.hs, + { + "room": { + "timeline": { + "limit": 1, + "not_types": ["not_a_normal_message"], + } + } + }, + ), + ), + since_token=initial_sync_result.next_batch, + ) + ) + + # The state event should appear in the 'state' section of the response. + room_sync = incremental_sync.joined[0] + self.assertEqual(room_sync.room_id, room_id) + self.assertTrue(room_sync.timeline.limited) + self.assertEqual( + [e.event_id for e in room_sync.timeline.events], + [e3_event], + ) + self.assertEqual( + [e.event_id for e in room_sync.state.values()], + [s2_event], + ) + @parameterized.expand( [ (False, False), From 0e68e9b7f4dd64d1b4b28feb4050e4b4fd85fb9d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 4 Apr 2024 17:15:35 +0100 Subject: [PATCH 6/7] Fix bug in calculating state for non-gappy syncs (#16942) Unfortunately, the optimisation we applied here for non-gappy syncs is not actually valid. Fixes https://github.com/element-hq/synapse/issues/16941. ~~Based on https://github.com/element-hq/synapse/pull/16930.~~ Requires https://github.com/matrix-org/sytest/pull/1374. --- changelog.d/16930.bugfix | 2 +- changelog.d/16932.bugfix | 2 +- changelog.d/16942.bugfix | 1 + synapse/handlers/sync.py | 89 +++++++++++++----------------- tests/handlers/test_sync.py | 105 ++++++++++++++++++++++++++++++++++++ 5 files changed, 144 insertions(+), 55 deletions(-) create mode 100644 changelog.d/16942.bugfix diff --git a/changelog.d/16930.bugfix b/changelog.d/16930.bugfix index 21f964ef97..99ed435d75 100644 --- a/changelog.d/16930.bugfix +++ b/changelog.d/16930.bugfix @@ -1 +1 @@ -Fix a long-standing bug which could cause state to be omitted from `/sync` responses when certain events are filtered out of the timeline. +Fix various long-standing bugs which could cause incorrect state to be returned from `/sync` in certain situations. diff --git a/changelog.d/16932.bugfix b/changelog.d/16932.bugfix index 624388ea8e..99ed435d75 100644 --- a/changelog.d/16932.bugfix +++ b/changelog.d/16932.bugfix @@ -1 +1 @@ -Fix a long-standing bug which could cause incorrect state to be returned from `/sync` for rooms where the user has left. \ No newline at end of file +Fix various long-standing bugs which could cause incorrect state to be returned from `/sync` in certain situations. diff --git a/changelog.d/16942.bugfix b/changelog.d/16942.bugfix new file mode 100644 index 0000000000..99ed435d75 --- /dev/null +++ b/changelog.d/16942.bugfix @@ -0,0 +1 @@ +Fix various long-standing bugs which could cause incorrect state to be returned from `/sync` in certain situations. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 773e291aa8..554c820f79 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1195,6 +1195,8 @@ class SyncHandler: ) if batch: + # Strictly speaking, this returns the state *after* the first event in the + # timeline, but that is good enough here. state_at_timeline_start = ( await self._state_storage_controller.get_state_ids_for_event( batch.events[0].event_id, @@ -1257,25 +1259,25 @@ class SyncHandler: await_full_state = True lazy_load_members = False - if batch.limited: - if batch: - state_at_timeline_start = ( - await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, - state_filter=state_filter, - await_full_state=await_full_state, - ) - ) - else: - # We can get here if the user has ignored the senders of all - # the recent events. - state_at_timeline_start = await self.get_state_at( - room_id, - stream_position=end_token, + if batch: + state_at_timeline_start = ( + await self._state_storage_controller.get_state_ids_for_event( + batch.events[0].event_id, state_filter=state_filter, await_full_state=await_full_state, ) + ) + else: + # We can get here if the user has ignored the senders of all + # the recent events. + state_at_timeline_start = await self.get_state_at( + room_id, + stream_position=end_token, + state_filter=state_filter, + await_full_state=await_full_state, + ) + if batch.limited: # for now, we disable LL for gappy syncs - see # https://github.com/vector-im/riot-web/issues/7211#issuecomment-419976346 # N.B. this slows down incr syncs as we are now processing way @@ -1290,47 +1292,28 @@ class SyncHandler: # about them). state_filter = StateFilter.all() - state_at_previous_sync = await self.get_state_at( - room_id, - stream_position=since_token, - state_filter=state_filter, - await_full_state=await_full_state, - ) + state_at_previous_sync = await self.get_state_at( + room_id, + stream_position=since_token, + state_filter=state_filter, + await_full_state=await_full_state, + ) - state_at_timeline_end = await self.get_state_at( - room_id, - stream_position=end_token, - state_filter=state_filter, - await_full_state=await_full_state, - ) + state_at_timeline_end = await self.get_state_at( + room_id, + stream_position=end_token, + state_filter=state_filter, + await_full_state=await_full_state, + ) - state_ids = _calculate_state( - timeline_contains=timeline_state, - timeline_start=state_at_timeline_start, - timeline_end=state_at_timeline_end, - previous_timeline_end=state_at_previous_sync, - lazy_load_members=lazy_load_members, - ) - else: - state_ids = {} - if lazy_load_members: - if members_to_fetch and batch.events: - # We're returning an incremental sync, with no - # "gap" since the previous sync, so normally there would be - # no state to return. - # But we're lazy-loading, so the client might need some more - # member events to understand the events in this timeline. - # So we fish out all the member events corresponding to the - # timeline here. The caller will then dedupe any redundant ones. + state_ids = _calculate_state( + timeline_contains=timeline_state, + timeline_start=state_at_timeline_start, + timeline_end=state_at_timeline_end, + previous_timeline_end=state_at_previous_sync, + lazy_load_members=lazy_load_members, + ) - state_ids = await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, - # we only want members! - state_filter=StateFilter.from_types( - (EventTypes.Member, member) for member in members_to_fetch - ), - await_full_state=False, - ) return state_ids async def _find_missing_partial_state_memberships( diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 5d8e886541..57e14d79ca 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -435,6 +435,111 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): [s2_event], ) + def test_state_includes_changes_on_ungappy_syncs(self) -> None: + """Test `state` where the sync is not gappy. + + We start with a DAG like this: + + E1 + ↗ ↖ + | S2 + | + --|--- + | + E3 + + ... and initialsync with `limit=1`, represented by the horizontal dashed line. + At this point, we do not expect S2 to appear in the response at all (since + it is excluded from the timeline by the `limit`, and the state is based on the + state after the most recent event before the sync token (E3), which doesn't + include S2. + + Now more events arrive, and we do an incremental sync: + + E1 + ↗ ↖ + | S2 + | ↑ + E3 | + ↑ | + --|------|---- + | | + E4 | + ↖ / + E5 + + This is the last chance for us to tell the client about S2, so it *must* be + included in the response. + """ + alice = self.register_user("alice", "password") + alice_tok = self.login(alice, "password") + alice_requester = create_requester(alice) + room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok) + + # Do an initial sync to get a known starting point. + initial_sync_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + alice_requester, generate_sync_config(alice) + ) + ) + last_room_creation_event_id = ( + initial_sync_result.joined[0].timeline.events[-1].event_id + ) + + # Send a state event, and a regular event, both using the same prev ID + with self._patch_get_latest_events([last_room_creation_event_id]): + s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[ + "event_id" + ] + e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"] + + # Another initial sync, with limit=1 + initial_sync_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + alice_requester, + generate_sync_config( + alice, + filter_collection=FilterCollection( + self.hs, {"room": {"timeline": {"limit": 1}}} + ), + ), + ) + ) + room_sync = initial_sync_result.joined[0] + self.assertEqual(room_sync.room_id, room_id) + self.assertEqual( + [e.event_id for e in room_sync.timeline.events], + [e3_event], + ) + self.assertNotIn(s2_event, [e.event_id for e in room_sync.state.values()]) + + # More events, E4 and E5 + with self._patch_get_latest_events([e3_event]): + e4_event = self.helper.send(room_id, "e4", tok=alice_tok)["event_id"] + e5_event = self.helper.send(room_id, "e5", tok=alice_tok)["event_id"] + + # Now incremental sync + incremental_sync = self.get_success( + self.sync_handler.wait_for_sync_for_user( + alice_requester, + generate_sync_config(alice), + since_token=initial_sync_result.next_batch, + ) + ) + + # The state event should appear in the 'state' section of the response. + room_sync = incremental_sync.joined[0] + self.assertEqual(room_sync.room_id, room_id) + self.assertFalse(room_sync.timeline.limited) + self.assertEqual( + [e.event_id for e in room_sync.timeline.events], + [e4_event, e5_event], + ) + self.assertEqual( + [e.event_id for e in room_sync.state.values()], + [s2_event], + ) + @parameterized.expand( [ (False, False), From 5360baeb6439366c29d55038da7f677c64eea4bf Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 5 Apr 2024 12:46:34 +0100 Subject: [PATCH 7/7] Pull out fewer receipts from DB when doing push (#17049) Before we were pulling out *all* read receipts for a user for every event we pushed. Instead let's only pull out the relevant receipts. This also pulled out the event rows for each receipt, causing load on the events table. --- changelog.d/17049.misc | 1 + .../databases/main/event_push_actions.py | 124 ++++++++++++++---- 2 files changed, 103 insertions(+), 22 deletions(-) create mode 100644 changelog.d/17049.misc diff --git a/changelog.d/17049.misc b/changelog.d/17049.misc new file mode 100644 index 0000000000..f71a6473a2 --- /dev/null +++ b/changelog.d/17049.misc @@ -0,0 +1 @@ +Improve database performance by reducing number of receipts fetched when sending push notifications. diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 3a5666cd9b..40bf000e9c 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -106,7 +106,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore -from synapse.types import JsonDict +from synapse.types import JsonDict, StrCollection from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -859,37 +859,86 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas return await self.db_pool.runInteraction("get_push_action_users_in_range", f) - def _get_receipts_by_room_txn( - self, txn: LoggingTransaction, user_id: str + def _get_receipts_for_room_and_threads_txn( + self, + txn: LoggingTransaction, + user_id: str, + room_ids: StrCollection, + thread_ids: StrCollection, ) -> Dict[str, _RoomReceipt]: """ - Generate a map of room ID to the latest stream ordering that has been - read by the given user. + Get (private) read receipts for a user in each of the given room IDs + and thread IDs. - Args: - txn: - user_id: The user to fetch receipts for. + Note: The corresponding room ID for each thread must appear in + `room_ids` arg. Returns: A map including all rooms the user is in with a receipt. It maps room IDs to _RoomReceipt instances """ - receipt_types_clause, args = make_in_list_sql_clause( + + receipt_types_clause, receipts_args = make_in_list_sql_clause( self.database_engine, "receipt_type", (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), ) + thread_ids_clause, thread_ids_args = make_in_list_sql_clause( + self.database_engine, + "thread_id", + thread_ids, + ) + + room_ids_clause, room_ids_args = make_in_list_sql_clause( + self.database_engine, + "room_id", + room_ids, + ) + + # We use the union of two (almost identical) queries here, the first to + # fetch the specific thread receipts and the second to fetch the + # unthreaded receipts. + # + # This SQL is optimized to use the indices we have on + # `receipts_linearized`. + # + # We compare room ID and thread IDs independently due to the above, + # which means that this query might return more rows than we need if the + # same thread ID appears across different rooms (e.g. 'main' thread ID). + # This doesn't cause any logic issues, and isn't a performance concern + # given this function generally gets called with only one room and + # thread ID. sql = f""" SELECT room_id, thread_id, MAX(stream_ordering) FROM receipts_linearized INNER JOIN events USING (room_id, event_id) WHERE {receipt_types_clause} + AND {thread_ids_clause} + AND {room_ids_clause} + AND user_id = ? + GROUP BY room_id, thread_id + + UNION ALL + + SELECT room_id, thread_id, MAX(stream_ordering) + FROM receipts_linearized + INNER JOIN events USING (room_id, event_id) + WHERE {receipt_types_clause} + AND {room_ids_clause} + AND thread_id IS NULL AND user_id = ? GROUP BY room_id, thread_id """ - args.extend((user_id,)) + args = list(receipts_args) + args.extend(thread_ids_args) + args.extend(room_ids_args) + args.append(user_id) + args.extend(receipts_args) + args.extend(room_ids_args) + args.append(user_id) + txn.execute(sql, args) result: Dict[str, _RoomReceipt] = {} @@ -925,12 +974,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will have between 0~limit entries. """ - receipts_by_room = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_http_receipts", - self._get_receipts_by_room_txn, - user_id=user_id, - ) - def get_push_actions_txn( txn: LoggingTransaction, ) -> List[Tuple[str, str, str, int, str, bool]]: @@ -952,6 +995,27 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn ) + room_ids = set() + thread_ids = [] + for ( + _, + room_id, + thread_id, + _, + _, + _, + ) in push_actions: + room_ids.add(room_id) + thread_ids.append(thread_id) + + receipts_by_room = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_http_receipts", + self._get_receipts_for_room_and_threads_txn, + user_id=user_id, + room_ids=room_ids, + thread_ids=thread_ids, + ) + notifs = [ HttpPushAction( event_id=event_id, @@ -998,12 +1062,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will have between 0~limit entries. """ - receipts_by_room = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_email_receipts", - self._get_receipts_by_room_txn, - user_id=user_id, - ) - def get_push_actions_txn( txn: LoggingTransaction, ) -> List[Tuple[str, str, str, int, str, bool, int]]: @@ -1026,6 +1084,28 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn ) + room_ids = set() + thread_ids = [] + for ( + _, + room_id, + thread_id, + _, + _, + _, + _, + ) in push_actions: + room_ids.add(room_id) + thread_ids.append(thread_id) + + receipts_by_room = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_email_receipts", + self._get_receipts_for_room_and_threads_txn, + user_id=user_id, + room_ids=room_ids, + thread_ids=thread_ids, + ) + # Make a list of dicts from the two sets of results. notifs = [ EmailPushAction(