Fix MSC4222 returning full state (#17915)

There was a bug that meant we would return the full state of the room on
incremental syncs when using lazy loaded members and there were no
entries in the timeline.

This was due to trying to use `state_filter or state_filter.all()` as a
short hand for handling `None` case, however `state_filter` implements
`__bool__` so if the state filter was empty it would be set to full.

c.f. MSC4222 and #17888
This commit is contained in:
Erik Johnston 2024-11-08 16:41:24 +00:00 committed by GitHub
parent c7a1d0aa1a
commit cacd4fd7bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 91 additions and 31 deletions

1
changelog.d/17915.bugfix Normal file
View file

@ -0,0 +1 @@
Fix experimental support for [MSC4222](https://github.com/matrix-org/matrix-spec-proposals/pull/4222) where we would return the full state on incremental syncs when using lazy loaded members and there were no new events in the timeline.

View file

@ -196,7 +196,9 @@ class MessageHandler:
AuthError (403) if the user doesn't have permission to view AuthError (403) if the user doesn't have permission to view
members of this room. members of this room.
""" """
state_filter = state_filter or StateFilter.all() if state_filter is None:
state_filter = StateFilter.all()
user_id = requester.user.to_string() user_id = requester.user.to_string()
if at_token: if at_token:

View file

@ -1520,7 +1520,7 @@ class SyncHandler:
if sync_config.use_state_after: if sync_config.use_state_after:
delta_state_ids: MutableStateMap[str] = {} delta_state_ids: MutableStateMap[str] = {}
if members_to_fetch is not None: if members_to_fetch:
# We're lazy-loading, so the client might need some more member # We're lazy-loading, so the client might need some more member
# events to understand the events in this timeline. So we always # events to understand the events in this timeline. So we always
# fish out all the member events corresponding to the timeline # fish out all the member events corresponding to the timeline

View file

@ -234,8 +234,11 @@ class StateStorageController:
RuntimeError if we don't have a state group for one or more of the events RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown) (ie they are outliers or unknown)
""" """
if state_filter is None:
state_filter = StateFilter.all()
await_full_state = True await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id): if not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False await_full_state = False
event_to_groups = await self.get_state_group_for_events( event_to_groups = await self.get_state_group_for_events(
@ -244,7 +247,7 @@ class StateStorageController:
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups( group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all() groups, state_filter
) )
state_event_map = await self.stores.main.get_events( state_event_map = await self.stores.main.get_events(
@ -292,10 +295,11 @@ class StateStorageController:
RuntimeError if we don't have a state group for one or more of the events RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown) (ie they are outliers or unknown)
""" """
if ( if state_filter is None:
await_full_state state_filter = StateFilter.all()
and state_filter
and not state_filter.must_await_full_state(self._is_mine_id) if await_full_state and not state_filter.must_await_full_state(
self._is_mine_id
): ):
# Full state is not required if the state filter is restrictive enough. # Full state is not required if the state filter is restrictive enough.
await_full_state = False await_full_state = False
@ -306,7 +310,7 @@ class StateStorageController:
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups( group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all() groups, state_filter
) )
event_to_state = { event_to_state = {
@ -335,9 +339,10 @@ class StateStorageController:
RuntimeError if we don't have a state group for the event (ie it is an RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown) outlier or is unknown)
""" """
state_map = await self.get_state_for_events( if state_filter is None:
[event_id], state_filter or StateFilter.all() state_filter = StateFilter.all()
)
state_map = await self.get_state_for_events([event_id], state_filter)
return state_map[event_id] return state_map[event_id]
@trace @trace
@ -365,9 +370,12 @@ class StateStorageController:
RuntimeError if we don't have a state group for the event (ie it is an RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown) outlier or is unknown)
""" """
if state_filter is None:
state_filter = StateFilter.all()
state_map = await self.get_state_ids_for_events( state_map = await self.get_state_ids_for_events(
[event_id], [event_id],
state_filter or StateFilter.all(), state_filter,
await_full_state=await_full_state, await_full_state=await_full_state,
) )
return state_map[event_id] return state_map[event_id]
@ -388,9 +396,12 @@ class StateStorageController:
at the event and `state_filter` is not satisfied by partial state. at the event and `state_filter` is not satisfied by partial state.
Defaults to `True`. Defaults to `True`.
""" """
if state_filter is None:
state_filter = StateFilter.all()
state_ids = await self.get_state_ids_for_event( state_ids = await self.get_state_ids_for_event(
event_id, event_id,
state_filter=state_filter or StateFilter.all(), state_filter=state_filter,
await_full_state=await_full_state, await_full_state=await_full_state,
) )
@ -426,6 +437,9 @@ class StateStorageController:
at the last event in the room before `stream_position` and at the last event in the room before `stream_position` and
`state_filter` is not satisfied by partial state. Defaults to `True`. `state_filter` is not satisfied by partial state. Defaults to `True`.
""" """
if state_filter is None:
state_filter = StateFilter.all()
# FIXME: This gets the state at the latest event before the stream ordering, # FIXME: This gets the state at the latest event before the stream ordering,
# which might not be the same as the "current state" of the room at the time # which might not be the same as the "current state" of the room at the time
# of the stream token if there were multiple forward extremities at the time. # of the stream token if there were multiple forward extremities at the time.
@ -442,7 +456,7 @@ class StateStorageController:
if last_event_id: if last_event_id:
state = await self.get_state_after_event( state = await self.get_state_after_event(
last_event_id, last_event_id,
state_filter=state_filter or StateFilter.all(), state_filter=state_filter,
await_full_state=await_full_state, await_full_state=await_full_state,
) )
@ -500,9 +514,10 @@ class StateStorageController:
Returns: Returns:
Dict of state group to state map. Dict of state group to state map.
""" """
return await self.stores.state._get_state_for_groups( if state_filter is None:
groups, state_filter or StateFilter.all() state_filter = StateFilter.all()
)
return await self.stores.state._get_state_for_groups(groups, state_filter)
@trace @trace
@tag_args @tag_args
@ -583,12 +598,13 @@ class StateStorageController:
Returns: Returns:
The current state of the room. The current state of the room.
""" """
if await_full_state and ( if state_filter is None:
not state_filter or state_filter.must_await_full_state(self._is_mine_id) state_filter = StateFilter.all()
):
if await_full_state and state_filter.must_await_full_state(self._is_mine_id):
await self._partial_state_room_tracker.await_full_state(room_id) await self._partial_state_room_tracker.await_full_state(room_id)
if state_filter and not state_filter.is_full(): if state_filter is not None and not state_filter.is_full():
return await self.stores.main.get_partial_filtered_current_state_ids( return await self.stores.main.get_partial_filtered_current_state_ids(
room_id, state_filter room_id, state_filter
) )

View file

@ -572,10 +572,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns: Returns:
Map from type/state_key to event ID. Map from type/state_key to event ID.
""" """
if state_filter is None:
state_filter = StateFilter.all()
where_clause, where_args = ( where_clause, where_args = (state_filter).make_sql_filter_clause()
state_filter or StateFilter.all()
).make_sql_filter_clause()
if not where_clause: if not where_clause:
# We delegate to the cached version # We delegate to the cached version
@ -584,7 +584,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_filtered_current_state_ids_txn( def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> StateMap[str]: ) -> StateMap[str]:
results = StateMapWrapper(state_filter=state_filter or StateFilter.all()) results = StateMapWrapper(state_filter=state_filter)
sql = """ sql = """
SELECT type, state_key, event_id FROM current_state_events SELECT type, state_key, event_id FROM current_state_events

View file

@ -112,8 +112,8 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
Returns: Returns:
Map from state_group to a StateMap at that point. Map from state_group to a StateMap at that point.
""" """
if state_filter is None:
state_filter = state_filter or StateFilter.all() state_filter = StateFilter.all()
results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups} results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}

View file

@ -284,7 +284,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Returns: Returns:
Dict of state group to state map. Dict of state group to state map.
""" """
state_filter = state_filter or StateFilter.all() if state_filter is None:
state_filter = StateFilter.all()
member_filter, non_member_filter = state_filter.get_member_split() member_filter, non_member_filter = state_filter.get_member_split()

View file

@ -68,15 +68,23 @@ class StateFilter:
include_others: bool = False include_others: bool = False
def __attrs_post_init__(self) -> None: def __attrs_post_init__(self) -> None:
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others: if self.include_others:
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
# this is needed to work around the fact that StateFilter is frozen # this is needed to work around the fact that StateFilter is frozen
object.__setattr__( object.__setattr__(
self, self,
"types", "types",
immutabledict({k: v for k, v in self.types.items() if v is not None}), immutabledict({k: v for k, v in self.types.items() if v is not None}),
) )
else:
# Otherwise we remove entries where the value is the empty set.
object.__setattr__(
self,
"types",
immutabledict({k: v for k, v in self.types.items() if v is None or v}),
)
@staticmethod @staticmethod
def all() -> "StateFilter": def all() -> "StateFilter":

View file

@ -1262,3 +1262,35 @@ class SyncStateAfterTestCase(tests.unittest.HomeserverTestCase):
) )
) )
self.assertEqual(state[("m.test_event", "")], second_state["event_id"]) self.assertEqual(state[("m.test_event", "")], second_state["event_id"])
def test_incremental_sync_lazy_loaded_no_timeline(self) -> None:
"""Test that lazy-loading with an empty timeline doesn't return the full
state.
There was a bug where an empty state filter would cause the DB to return
the full state, rather than an empty set.
"""
user = self.register_user("user", "password")
tok = self.login("user", "password")
# Create a room as the user and set some custom state.
joined_room = self.helper.create_room_as(user, tok=tok)
since_token = self.hs.get_event_sources().get_current_token()
end_stream_token = self.hs.get_event_sources().get_current_token()
state = self.get_success(
self.sync_handler._compute_state_delta_for_incremental_sync(
room_id=joined_room,
sync_config=generate_sync_config(user, use_state_after=True),
batch=TimelineBatch(
prev_batch=end_stream_token, events=[], limited=True
),
since_token=since_token,
end_token=end_stream_token,
members_to_fetch=set(),
timeline_state={},
)
)
self.assertEqual(state, {})