diff --git a/synapse/handlers/sliding_sync/__init__.py b/synapse/handlers/sliding_sync/__init__.py index 9fcc68ff25..4915682f6e 100644 --- a/synapse/handlers/sliding_sync/__init__.py +++ b/synapse/handlers/sliding_sync/__init__.py @@ -14,7 +14,7 @@ import logging from itertools import chain -from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple +from typing import TYPE_CHECKING, AbstractSet, Dict, List, Mapping, Optional, Set, Tuple from prometheus_client import Histogram from typing_extensions import assert_never @@ -988,6 +988,10 @@ class SlidingSyncHandler: include_others=required_state_filter.include_others, ) + # The required state map to store in the room sync config, if it has + # changed. + changed_required_state_map: Optional[Mapping[str, AbstractSet[str]]] = None + # We can return all of the state that was requested if this was the first # time we've sent the room down this connection. room_state: StateMap[EventBase] = {} @@ -1001,6 +1005,29 @@ class SlidingSyncHandler: else: assert from_bound is not None + if prev_room_sync_config is not None: + # Check if there are any changes to the required state config + # that we need to handle. + changed_required_state_map, added_state_filter = ( + _required_state_changes( + user.to_string(), + prev_room_sync_config, + room_sync_config, + room_state_delta_id_map, + ) + ) + + if added_state_filter: + # Some state entries got added, so we pull out the current + # state for them. If we don't do this we'd only send down new deltas. + state_ids = await self.get_current_state_ids_at( + room_id=room_id, + room_membership_for_user_at_to_token=room_membership_for_user_at_to_token, + state_filter=added_state_filter, + to_token=to_token, + ) + room_state_delta_id_map.update(state_ids) + events = await self.store.get_events( state_filter.filter_state(room_state_delta_id_map).values() ) @@ -1083,6 +1110,10 @@ class SlidingSyncHandler: prev_room_sync_config = previous_connection_state.room_configs.get(room_id) # Record the `room_sync_config` if we're `ignore_timeline_bound` (which means # that the `timeline_limit` has increased) + room_sync_required_state_map_to_persist = room_sync_config.required_state_map + if changed_required_state_map: + room_sync_required_state_map_to_persist = changed_required_state_map + if ignore_timeline_bound: # FIXME: We signal the fact that we're sending down more events to # the client by setting `unstable_expanded_timeline` to true (see @@ -1091,7 +1122,7 @@ class SlidingSyncHandler: new_connection_state.room_configs[room_id] = RoomSyncConfig( timeline_limit=room_sync_config.timeline_limit, - required_state_map=room_sync_config.required_state_map, + required_state_map=room_sync_required_state_map_to_persist, ) elif prev_room_sync_config is not None: # If the result is `limited` then we need to record that the @@ -1120,10 +1151,14 @@ class SlidingSyncHandler: ): new_connection_state.room_configs[room_id] = RoomSyncConfig( timeline_limit=room_sync_config.timeline_limit, - required_state_map=room_sync_config.required_state_map, + required_state_map=room_sync_required_state_map_to_persist, ) - # TODO: Record changes in required_state. + elif changed_required_state_map is not None: + new_connection_state.room_configs[room_id] = RoomSyncConfig( + timeline_limit=room_sync_config.timeline_limit, + required_state_map=room_sync_required_state_map_to_persist, + ) else: new_connection_state.room_configs[room_id] = room_sync_config @@ -1242,3 +1277,167 @@ class SlidingSyncHandler: return new_bump_event_pos.stream return None + + +def _required_state_changes( + user_id: str, + previous_room_config: "RoomSyncConfig", + room_sync_config: RoomSyncConfig, + state_deltas: StateMap[str], +) -> Tuple[Optional[Mapping[str, AbstractSet[str]]], StateFilter]: + """Calculates the changes between the required state room config from the + previous requests compared with the current request. + + This does two things. First, it calculates if we need to update the room + config due to changes to required state. Secondly, it works out which state + entries we need to pull from current state and return due to the state entry + now appearing in the required state when it previously wasn't (on top of the + state deltas). + + This function tries to ensure to handle the case where a state entry is + added, removed and then added again to the required state. In that case we + only want to re-send that entry down sync if it has changed. + + Returns: + A 2-tuple of updated required state config and the state filter to use + to fetch extra current state that we need to return. + """ + + prev_required_state_map = previous_room_config.required_state_map + request_required_state_map = room_sync_config.required_state_map + + if prev_required_state_map == request_required_state_map: + # There has been no change. Return immediately. + return None, StateFilter.none() + + prev_wildcard = prev_required_state_map.get(StateValues.WILDCARD, set()) + request_wildcard = request_required_state_map.get(StateValues.WILDCARD, set()) + + # If a event type wildcard has been added or removed we don't try and do + # anything fancy, and instead always update the effective room required + # state config to match the request. + if request_wildcard - prev_wildcard: + # Some keys were added, so we need to fetch everything + return request_required_state_map, StateFilter.all() + if prev_wildcard - request_wildcard: + # Keys were only removed, so we don't have to fetch everything. + return request_required_state_map, StateFilter.none() + + # Contains updates to the required state map compared with the previous room + # config. This has the same format as `RoomSyncConfig.required_state` + changes: Dict[str, AbstractSet[str]] = {} + + # The set of types/state keys that we need to fetch and return to the + # client. Passed to `StateFilter.from_types(...)` + added: List[Tuple[str, Optional[str]]] = [] + + # First we calculate what, if anything, has been *added*. + for event_type in ( + prev_required_state_map.keys() | request_required_state_map.keys() + ): + old_state_keys = prev_required_state_map.get(event_type, set()) + request_state_keys = request_required_state_map.get(event_type, set()) + + if old_state_keys == request_state_keys: + # No change to this type + continue + + if not request_state_keys - old_state_keys: + # Nothing *added*, so we skip. Removals happen below. + continue + + # Always update changes to include the newly added keys + changes[event_type] = request_state_keys + + # Record the new state keys to fetch for this type. + if StateValues.WILDCARD in request_state_keys: + # If we have added a wildcard then we always just fetch everything. + added.append((event_type, None)) + else: + for state_key in request_state_keys - old_state_keys: + if state_key == StateValues.ME: + added.append((event_type, user_id)) + elif state_key == StateValues.LAZY: + # We handle lazy loading separately, so don't need to + # explicitly add anything here. + pass + else: + added.append((event_type, state_key)) + + added_state_filter = StateFilter.from_types(added) + + # Convert the list of state deltas to map from type to state_keys that have + # changed. + changed_types_to_state_keys: Dict[str, Set[str]] = {} + for event_type, state_key in state_deltas: + changed_types_to_state_keys.setdefault(event_type, set()).add(state_key) + + # Figure out what changes we need to apply to the effective required state + # config. + for event_type, changed_state_keys in changed_types_to_state_keys.items(): + old_state_keys = prev_required_state_map.get(event_type, set()) + request_state_keys = request_required_state_map.get(event_type, set()) + + if old_state_keys == request_state_keys: + # No change. + continue + + if request_state_keys - old_state_keys: + # We've expanded the set of state keys, so we just clobber the + # current set with the new set. + # + # We could also ensure that we keep entries where the state hasn't + # changed, but are no longer in the requested required state, but + # that's a sufficient edge case that we can ignore (as its only a + # performance optimization). + changes[event_type] = request_state_keys + continue + + old_state_key_wildcard = StateValues.WILDCARD in old_state_keys + request_state_key_wildcard = StateValues.WILDCARD in request_state_keys + + if old_state_key_wildcard != request_state_key_wildcard: + # If a wildcard has been added or removed we always update the + # required state when any state with the same type has changed. + changes[event_type] = request_state_keys + continue + + old_state_key_lazy = StateValues.LAZY in old_state_keys + request_state_key_lazy = StateValues.LAZY in request_state_keys + + if old_state_key_lazy != request_state_key_lazy: + # If a "$LAZY" has been added or removed we always update the + # required state for simplicity. + changes[event_type] = request_state_keys + continue + + # Handle "$ME" values by replacing state keys that match the user ID. + if user_id in changed_state_keys: + changed_state_keys.discard(user_id) + changed_state_keys.add(StateValues.ME) + + # At this point there are no wildcards and no additions to the set of + # state keys requested, only deletions. + # + # We only remove state keys from the effective state if they've been + # removed from the request *and* the state has changed. This ensures + # that if a client removes and then readds a state key, we only send + # down the associated current state event if its changed (rather than + # sending down the same event twice). + invalidated = (old_state_keys - request_state_keys) & changed_state_keys + if invalidated: + changes[event_type] = old_state_keys - invalidated + + if changes: + # Update the required state config based on the changes. + new_required_state_map = dict(prev_required_state_map) + for event_type, state_keys in changes.items(): + if state_keys: + new_required_state_map[event_type] = state_keys + else: + # Remove entries with empty state keys. + new_required_state_map.pop(event_type, None) + + return new_required_state_map, added_state_filter + else: + return None, added_state_filter diff --git a/synapse/storage/databases/main/sliding_sync.py b/synapse/storage/databases/main/sliding_sync.py index f2df37fec1..7b357c1ffe 100644 --- a/synapse/storage/databases/main/sliding_sync.py +++ b/synapse/storage/databases/main/sliding_sync.py @@ -386,8 +386,8 @@ class SlidingSyncStore(SQLBaseStore): required_state_map: Dict[int, Dict[str, Set[str]]] = {} for row in rows: state = required_state_map[row[0]] = {} - for event_type, state_keys in db_to_json(row[1]): - state[event_type] = set(state_keys) + for event_type, state_key in db_to_json(row[1]): + state.setdefault(event_type, set()).add(state_key) # Get all the room configs, looking up the required state from the map # above. diff --git a/synapse/types/state.py b/synapse/types/state.py index 1141c4b5c1..060a5b6183 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py @@ -616,6 +616,11 @@ class StateFilter: return False + def __bool__(self) -> bool: + if self.include_others: + return True + return bool(self.types) + _ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True) _ALL_NON_MEMBER_STATE_FILTER = StateFilter( diff --git a/tests/rest/client/sliding_sync/test_rooms_required_state.py b/tests/rest/client/sliding_sync/test_rooms_required_state.py index 91ac6c5a0e..0d98a6d889 100644 --- a/tests/rest/client/sliding_sync/test_rooms_required_state.py +++ b/tests/rest/client/sliding_sync/test_rooms_required_state.py @@ -862,3 +862,268 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase): exact=True, message=f"Expected only fully-stated rooms to show up for test_key={list_key}.", ) + + def test_rooms_required_state_expand(self) -> None: + """Test that when we expand the required state argument we get the + expanded state, and not just the changes to the new expanded.""" + + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Create a room with a room name. + room_id1 = self.helper.create_room_as( + user1_id, tok=user1_tok, extra_content={"name": "Foo"} + ) + + # Only request the state event to begin with + sync_body = { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [ + [EventTypes.Create, ""], + ], + "timeline_limit": 1, + } + } + } + response_body, from_token = self.do_sync(sync_body, tok=user1_tok) + + state_map = self.get_success( + self.storage_controllers.state.get_current_state(room_id1) + ) + + self._assertRequiredStateIncludes( + response_body["rooms"][room_id1]["required_state"], + { + state_map[(EventTypes.Create, "")], + }, + exact=True, + ) + + # Send a message so the room comes down sync. + self.helper.send(room_id1, "msg", tok=user1_tok) + + # Update the sliding sync requests to include the room name + sync_body["lists"]["foo-list"]["required_state"] = [ + [EventTypes.Create, ""], + [EventTypes.Name, ""], + ] + response_body, from_token = self.do_sync( + sync_body, since=from_token, tok=user1_tok + ) + + # We should see the room name, even though there haven't been any + # changes. + self._assertRequiredStateIncludes( + response_body["rooms"][room_id1]["required_state"], + { + state_map[(EventTypes.Name, "")], + }, + exact=True, + ) + + # Send a message so the room comes down sync. + self.helper.send(room_id1, "msg", tok=user1_tok) + + # We should not see any state changes. + response_body, from_token = self.do_sync( + sync_body, since=from_token, tok=user1_tok + ) + self.assertIsNone(response_body["rooms"][room_id1].get("required_state")) + + def test_rooms_required_state_expand_retract_expand(self) -> None: + """Test that when expanding, retracting and then expanding the required + state, we get the changes that happened.""" + + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Create a room with a room name. + room_id1 = self.helper.create_room_as( + user1_id, tok=user1_tok, extra_content={"name": "Foo"} + ) + + # Only request the state event to begin with + sync_body = { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [ + [EventTypes.Create, ""], + ], + "timeline_limit": 1, + } + } + } + response_body, from_token = self.do_sync(sync_body, tok=user1_tok) + + state_map = self.get_success( + self.storage_controllers.state.get_current_state(room_id1) + ) + + self._assertRequiredStateIncludes( + response_body["rooms"][room_id1]["required_state"], + { + state_map[(EventTypes.Create, "")], + }, + exact=True, + ) + + # Send a message so the room comes down sync. + self.helper.send(room_id1, "msg", tok=user1_tok) + + # Update the sliding sync requests to include the room name + sync_body["lists"]["foo-list"]["required_state"] = [ + [EventTypes.Create, ""], + [EventTypes.Name, ""], + ] + response_body, from_token = self.do_sync( + sync_body, since=from_token, tok=user1_tok + ) + + # We should see the room name, even though there haven't been any + # changes. + self._assertRequiredStateIncludes( + response_body["rooms"][room_id1]["required_state"], + { + state_map[(EventTypes.Name, "")], + }, + exact=True, + ) + + # Update the room name + self.helper.send_state( + room_id1, "m.room.name", {"name": "Bar"}, state_key="", tok=user1_tok + ) + + # Update the sliding sync requests to exclude the room name again + sync_body["lists"]["foo-list"]["required_state"] = [ + [EventTypes.Create, ""], + ] + response_body, from_token = self.do_sync( + sync_body, since=from_token, tok=user1_tok + ) + + # We should not see the updated room name in state (though it will be in + # the timeline). + self.assertIsNone(response_body["rooms"][room_id1].get("required_state")) + + # Send a message so the room comes down sync. + self.helper.send(room_id1, "msg", tok=user1_tok) + + # Update the sliding sync requests to include the room name again + sync_body["lists"]["foo-list"]["required_state"] = [ + [EventTypes.Create, ""], + [EventTypes.Name, ""], + ] + response_body, from_token = self.do_sync( + sync_body, since=from_token, tok=user1_tok + ) + + # We should see the *new* room name, even though there haven't been any + # changes. + state_map = self.get_success( + self.storage_controllers.state.get_current_state(room_id1) + ) + self._assertRequiredStateIncludes( + response_body["rooms"][room_id1]["required_state"], + { + state_map[(EventTypes.Name, "")], + }, + exact=True, + ) + + def test_rooms_required_state_expand_deduplicate(self) -> None: + """Test that when expanding, retracting and then expanding the required + state, we don't get the state down again if it hasn't changed""" + + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Create a room with a room name. + room_id1 = self.helper.create_room_as( + user1_id, tok=user1_tok, extra_content={"name": "Foo"} + ) + + # Only request the state event to begin with + sync_body = { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [ + [EventTypes.Create, ""], + ], + "timeline_limit": 1, + } + } + } + response_body, from_token = self.do_sync(sync_body, tok=user1_tok) + + state_map = self.get_success( + self.storage_controllers.state.get_current_state(room_id1) + ) + + self._assertRequiredStateIncludes( + response_body["rooms"][room_id1]["required_state"], + { + state_map[(EventTypes.Create, "")], + }, + exact=True, + ) + + # Send a message so the room comes down sync. + self.helper.send(room_id1, "msg", tok=user1_tok) + + # Update the sliding sync requests to include the room name + sync_body["lists"]["foo-list"]["required_state"] = [ + [EventTypes.Create, ""], + [EventTypes.Name, ""], + ] + response_body, from_token = self.do_sync( + sync_body, since=from_token, tok=user1_tok + ) + + # We should see the room name, even though there haven't been any + # changes. + self._assertRequiredStateIncludes( + response_body["rooms"][room_id1]["required_state"], + { + state_map[(EventTypes.Name, "")], + }, + exact=True, + ) + + # Send a message so the room comes down sync. + self.helper.send(room_id1, "msg", tok=user1_tok) + + # Update the sliding sync requests to exclude the room name again + sync_body["lists"]["foo-list"]["required_state"] = [ + [EventTypes.Create, ""], + ] + response_body, from_token = self.do_sync( + sync_body, since=from_token, tok=user1_tok + ) + + # We should not see any state updates + self.assertIsNone(response_body["rooms"][room_id1].get("required_state")) + + # Send a message so the room comes down sync. + self.helper.send(room_id1, "msg", tok=user1_tok) + + # Update the sliding sync requests to include the room name again + sync_body["lists"]["foo-list"]["required_state"] = [ + [EventTypes.Create, ""], + [EventTypes.Name, ""], + ] + response_body, from_token = self.do_sync( + sync_body, since=from_token, tok=user1_tok + ) + + # We should not see the room name again, as we have already sent that + # down. + state_map = self.get_success( + self.storage_controllers.state.get_current_state(room_id1) + ) + # Send a message so the room comes down sync. + self.assertIsNone(response_body["rooms"][room_id1].get("required_state"))