Correctly changes to required state config

Fixes #17698
This commit is contained in:
Erik Johnston 2024-10-01 16:53:18 +01:00
parent d4e3ad04cd
commit af9c72db9d
4 changed files with 475 additions and 6 deletions

View file

@ -14,7 +14,7 @@
import logging
from itertools import chain
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple
from typing import TYPE_CHECKING, AbstractSet, Dict, List, Mapping, Optional, Set, Tuple
from prometheus_client import Histogram
from typing_extensions import assert_never
@ -988,6 +988,10 @@ class SlidingSyncHandler:
include_others=required_state_filter.include_others,
)
# The required state map to store in the room sync config, if it has
# changed.
changed_required_state_map: Optional[Mapping[str, AbstractSet[str]]] = None
# We can return all of the state that was requested if this was the first
# time we've sent the room down this connection.
room_state: StateMap[EventBase] = {}
@ -1001,6 +1005,29 @@ class SlidingSyncHandler:
else:
assert from_bound is not None
if prev_room_sync_config is not None:
# Check if there are any changes to the required state config
# that we need to handle.
changed_required_state_map, added_state_filter = (
_required_state_changes(
user.to_string(),
prev_room_sync_config,
room_sync_config,
room_state_delta_id_map,
)
)
if added_state_filter:
# Some state entries got added, so we pull out the current
# state for them. If we don't do this we'd only send down new deltas.
state_ids = await self.get_current_state_ids_at(
room_id=room_id,
room_membership_for_user_at_to_token=room_membership_for_user_at_to_token,
state_filter=added_state_filter,
to_token=to_token,
)
room_state_delta_id_map.update(state_ids)
events = await self.store.get_events(
state_filter.filter_state(room_state_delta_id_map).values()
)
@ -1083,6 +1110,10 @@ class SlidingSyncHandler:
prev_room_sync_config = previous_connection_state.room_configs.get(room_id)
# Record the `room_sync_config` if we're `ignore_timeline_bound` (which means
# that the `timeline_limit` has increased)
room_sync_required_state_map_to_persist = room_sync_config.required_state_map
if changed_required_state_map:
room_sync_required_state_map_to_persist = changed_required_state_map
if ignore_timeline_bound:
# FIXME: We signal the fact that we're sending down more events to
# the client by setting `unstable_expanded_timeline` to true (see
@ -1091,7 +1122,7 @@ class SlidingSyncHandler:
new_connection_state.room_configs[room_id] = RoomSyncConfig(
timeline_limit=room_sync_config.timeline_limit,
required_state_map=room_sync_config.required_state_map,
required_state_map=room_sync_required_state_map_to_persist,
)
elif prev_room_sync_config is not None:
# If the result is `limited` then we need to record that the
@ -1120,10 +1151,14 @@ class SlidingSyncHandler:
):
new_connection_state.room_configs[room_id] = RoomSyncConfig(
timeline_limit=room_sync_config.timeline_limit,
required_state_map=room_sync_config.required_state_map,
required_state_map=room_sync_required_state_map_to_persist,
)
# TODO: Record changes in required_state.
elif changed_required_state_map is not None:
new_connection_state.room_configs[room_id] = RoomSyncConfig(
timeline_limit=room_sync_config.timeline_limit,
required_state_map=room_sync_required_state_map_to_persist,
)
else:
new_connection_state.room_configs[room_id] = room_sync_config
@ -1242,3 +1277,167 @@ class SlidingSyncHandler:
return new_bump_event_pos.stream
return None
def _required_state_changes(
user_id: str,
previous_room_config: "RoomSyncConfig",
room_sync_config: RoomSyncConfig,
state_deltas: StateMap[str],
) -> Tuple[Optional[Mapping[str, AbstractSet[str]]], StateFilter]:
"""Calculates the changes between the required state room config from the
previous requests compared with the current request.
This does two things. First, it calculates if we need to update the room
config due to changes to required state. Secondly, it works out which state
entries we need to pull from current state and return due to the state entry
now appearing in the required state when it previously wasn't (on top of the
state deltas).
This function tries to ensure to handle the case where a state entry is
added, removed and then added again to the required state. In that case we
only want to re-send that entry down sync if it has changed.
Returns:
A 2-tuple of updated required state config and the state filter to use
to fetch extra current state that we need to return.
"""
prev_required_state_map = previous_room_config.required_state_map
request_required_state_map = room_sync_config.required_state_map
if prev_required_state_map == request_required_state_map:
# There has been no change. Return immediately.
return None, StateFilter.none()
prev_wildcard = prev_required_state_map.get(StateValues.WILDCARD, set())
request_wildcard = request_required_state_map.get(StateValues.WILDCARD, set())
# If a event type wildcard has been added or removed we don't try and do
# anything fancy, and instead always update the effective room required
# state config to match the request.
if request_wildcard - prev_wildcard:
# Some keys were added, so we need to fetch everything
return request_required_state_map, StateFilter.all()
if prev_wildcard - request_wildcard:
# Keys were only removed, so we don't have to fetch everything.
return request_required_state_map, StateFilter.none()
# Contains updates to the required state map compared with the previous room
# config. This has the same format as `RoomSyncConfig.required_state`
changes: Dict[str, AbstractSet[str]] = {}
# The set of types/state keys that we need to fetch and return to the
# client. Passed to `StateFilter.from_types(...)`
added: List[Tuple[str, Optional[str]]] = []
# First we calculate what, if anything, has been *added*.
for event_type in (
prev_required_state_map.keys() | request_required_state_map.keys()
):
old_state_keys = prev_required_state_map.get(event_type, set())
request_state_keys = request_required_state_map.get(event_type, set())
if old_state_keys == request_state_keys:
# No change to this type
continue
if not request_state_keys - old_state_keys:
# Nothing *added*, so we skip. Removals happen below.
continue
# Always update changes to include the newly added keys
changes[event_type] = request_state_keys
# Record the new state keys to fetch for this type.
if StateValues.WILDCARD in request_state_keys:
# If we have added a wildcard then we always just fetch everything.
added.append((event_type, None))
else:
for state_key in request_state_keys - old_state_keys:
if state_key == StateValues.ME:
added.append((event_type, user_id))
elif state_key == StateValues.LAZY:
# We handle lazy loading separately, so don't need to
# explicitly add anything here.
pass
else:
added.append((event_type, state_key))
added_state_filter = StateFilter.from_types(added)
# Convert the list of state deltas to map from type to state_keys that have
# changed.
changed_types_to_state_keys: Dict[str, Set[str]] = {}
for event_type, state_key in state_deltas:
changed_types_to_state_keys.setdefault(event_type, set()).add(state_key)
# Figure out what changes we need to apply to the effective required state
# config.
for event_type, changed_state_keys in changed_types_to_state_keys.items():
old_state_keys = prev_required_state_map.get(event_type, set())
request_state_keys = request_required_state_map.get(event_type, set())
if old_state_keys == request_state_keys:
# No change.
continue
if request_state_keys - old_state_keys:
# We've expanded the set of state keys, so we just clobber the
# current set with the new set.
#
# We could also ensure that we keep entries where the state hasn't
# changed, but are no longer in the requested required state, but
# that's a sufficient edge case that we can ignore (as its only a
# performance optimization).
changes[event_type] = request_state_keys
continue
old_state_key_wildcard = StateValues.WILDCARD in old_state_keys
request_state_key_wildcard = StateValues.WILDCARD in request_state_keys
if old_state_key_wildcard != request_state_key_wildcard:
# If a wildcard has been added or removed we always update the
# required state when any state with the same type has changed.
changes[event_type] = request_state_keys
continue
old_state_key_lazy = StateValues.LAZY in old_state_keys
request_state_key_lazy = StateValues.LAZY in request_state_keys
if old_state_key_lazy != request_state_key_lazy:
# If a "$LAZY" has been added or removed we always update the
# required state for simplicity.
changes[event_type] = request_state_keys
continue
# Handle "$ME" values by replacing state keys that match the user ID.
if user_id in changed_state_keys:
changed_state_keys.discard(user_id)
changed_state_keys.add(StateValues.ME)
# At this point there are no wildcards and no additions to the set of
# state keys requested, only deletions.
#
# We only remove state keys from the effective state if they've been
# removed from the request *and* the state has changed. This ensures
# that if a client removes and then readds a state key, we only send
# down the associated current state event if its changed (rather than
# sending down the same event twice).
invalidated = (old_state_keys - request_state_keys) & changed_state_keys
if invalidated:
changes[event_type] = old_state_keys - invalidated
if changes:
# Update the required state config based on the changes.
new_required_state_map = dict(prev_required_state_map)
for event_type, state_keys in changes.items():
if state_keys:
new_required_state_map[event_type] = state_keys
else:
# Remove entries with empty state keys.
new_required_state_map.pop(event_type, None)
return new_required_state_map, added_state_filter
else:
return None, added_state_filter

View file

@ -386,8 +386,8 @@ class SlidingSyncStore(SQLBaseStore):
required_state_map: Dict[int, Dict[str, Set[str]]] = {}
for row in rows:
state = required_state_map[row[0]] = {}
for event_type, state_keys in db_to_json(row[1]):
state[event_type] = set(state_keys)
for event_type, state_key in db_to_json(row[1]):
state.setdefault(event_type, set()).add(state_key)
# Get all the room configs, looking up the required state from the map
# above.

View file

@ -616,6 +616,11 @@ class StateFilter:
return False
def __bool__(self) -> bool:
if self.include_others:
return True
return bool(self.types)
_ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True)
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(

View file

@ -862,3 +862,268 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
exact=True,
message=f"Expected only fully-stated rooms to show up for test_key={list_key}.",
)
def test_rooms_required_state_expand(self) -> None:
"""Test that when we expand the required state argument we get the
expanded state, and not just the changes to the new expanded."""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Create a room with a room name.
room_id1 = self.helper.create_room_as(
user1_id, tok=user1_tok, extra_content={"name": "Foo"}
)
# Only request the state event to begin with
sync_body = {
"lists": {
"foo-list": {
"ranges": [[0, 1]],
"required_state": [
[EventTypes.Create, ""],
],
"timeline_limit": 1,
}
}
}
response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
state_map = self.get_success(
self.storage_controllers.state.get_current_state(room_id1)
)
self._assertRequiredStateIncludes(
response_body["rooms"][room_id1]["required_state"],
{
state_map[(EventTypes.Create, "")],
},
exact=True,
)
# Send a message so the room comes down sync.
self.helper.send(room_id1, "msg", tok=user1_tok)
# Update the sliding sync requests to include the room name
sync_body["lists"]["foo-list"]["required_state"] = [
[EventTypes.Create, ""],
[EventTypes.Name, ""],
]
response_body, from_token = self.do_sync(
sync_body, since=from_token, tok=user1_tok
)
# We should see the room name, even though there haven't been any
# changes.
self._assertRequiredStateIncludes(
response_body["rooms"][room_id1]["required_state"],
{
state_map[(EventTypes.Name, "")],
},
exact=True,
)
# Send a message so the room comes down sync.
self.helper.send(room_id1, "msg", tok=user1_tok)
# We should not see any state changes.
response_body, from_token = self.do_sync(
sync_body, since=from_token, tok=user1_tok
)
self.assertIsNone(response_body["rooms"][room_id1].get("required_state"))
def test_rooms_required_state_expand_retract_expand(self) -> None:
"""Test that when expanding, retracting and then expanding the required
state, we get the changes that happened."""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Create a room with a room name.
room_id1 = self.helper.create_room_as(
user1_id, tok=user1_tok, extra_content={"name": "Foo"}
)
# Only request the state event to begin with
sync_body = {
"lists": {
"foo-list": {
"ranges": [[0, 1]],
"required_state": [
[EventTypes.Create, ""],
],
"timeline_limit": 1,
}
}
}
response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
state_map = self.get_success(
self.storage_controllers.state.get_current_state(room_id1)
)
self._assertRequiredStateIncludes(
response_body["rooms"][room_id1]["required_state"],
{
state_map[(EventTypes.Create, "")],
},
exact=True,
)
# Send a message so the room comes down sync.
self.helper.send(room_id1, "msg", tok=user1_tok)
# Update the sliding sync requests to include the room name
sync_body["lists"]["foo-list"]["required_state"] = [
[EventTypes.Create, ""],
[EventTypes.Name, ""],
]
response_body, from_token = self.do_sync(
sync_body, since=from_token, tok=user1_tok
)
# We should see the room name, even though there haven't been any
# changes.
self._assertRequiredStateIncludes(
response_body["rooms"][room_id1]["required_state"],
{
state_map[(EventTypes.Name, "")],
},
exact=True,
)
# Update the room name
self.helper.send_state(
room_id1, "m.room.name", {"name": "Bar"}, state_key="", tok=user1_tok
)
# Update the sliding sync requests to exclude the room name again
sync_body["lists"]["foo-list"]["required_state"] = [
[EventTypes.Create, ""],
]
response_body, from_token = self.do_sync(
sync_body, since=from_token, tok=user1_tok
)
# We should not see the updated room name in state (though it will be in
# the timeline).
self.assertIsNone(response_body["rooms"][room_id1].get("required_state"))
# Send a message so the room comes down sync.
self.helper.send(room_id1, "msg", tok=user1_tok)
# Update the sliding sync requests to include the room name again
sync_body["lists"]["foo-list"]["required_state"] = [
[EventTypes.Create, ""],
[EventTypes.Name, ""],
]
response_body, from_token = self.do_sync(
sync_body, since=from_token, tok=user1_tok
)
# We should see the *new* room name, even though there haven't been any
# changes.
state_map = self.get_success(
self.storage_controllers.state.get_current_state(room_id1)
)
self._assertRequiredStateIncludes(
response_body["rooms"][room_id1]["required_state"],
{
state_map[(EventTypes.Name, "")],
},
exact=True,
)
def test_rooms_required_state_expand_deduplicate(self) -> None:
"""Test that when expanding, retracting and then expanding the required
state, we don't get the state down again if it hasn't changed"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Create a room with a room name.
room_id1 = self.helper.create_room_as(
user1_id, tok=user1_tok, extra_content={"name": "Foo"}
)
# Only request the state event to begin with
sync_body = {
"lists": {
"foo-list": {
"ranges": [[0, 1]],
"required_state": [
[EventTypes.Create, ""],
],
"timeline_limit": 1,
}
}
}
response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
state_map = self.get_success(
self.storage_controllers.state.get_current_state(room_id1)
)
self._assertRequiredStateIncludes(
response_body["rooms"][room_id1]["required_state"],
{
state_map[(EventTypes.Create, "")],
},
exact=True,
)
# Send a message so the room comes down sync.
self.helper.send(room_id1, "msg", tok=user1_tok)
# Update the sliding sync requests to include the room name
sync_body["lists"]["foo-list"]["required_state"] = [
[EventTypes.Create, ""],
[EventTypes.Name, ""],
]
response_body, from_token = self.do_sync(
sync_body, since=from_token, tok=user1_tok
)
# We should see the room name, even though there haven't been any
# changes.
self._assertRequiredStateIncludes(
response_body["rooms"][room_id1]["required_state"],
{
state_map[(EventTypes.Name, "")],
},
exact=True,
)
# Send a message so the room comes down sync.
self.helper.send(room_id1, "msg", tok=user1_tok)
# Update the sliding sync requests to exclude the room name again
sync_body["lists"]["foo-list"]["required_state"] = [
[EventTypes.Create, ""],
]
response_body, from_token = self.do_sync(
sync_body, since=from_token, tok=user1_tok
)
# We should not see any state updates
self.assertIsNone(response_body["rooms"][room_id1].get("required_state"))
# Send a message so the room comes down sync.
self.helper.send(room_id1, "msg", tok=user1_tok)
# Update the sliding sync requests to include the room name again
sync_body["lists"]["foo-list"]["required_state"] = [
[EventTypes.Create, ""],
[EventTypes.Name, ""],
]
response_body, from_token = self.do_sync(
sync_body, since=from_token, tok=user1_tok
)
# We should not see the room name again, as we have already sent that
# down.
state_map = self.get_success(
self.storage_controllers.state.get_current_state(room_id1)
)
# Send a message so the room comes down sync.
self.assertIsNone(response_body["rooms"][room_id1].get("required_state"))