state_after WIP

Hugh Nimmo-Smith 2024-10-22 15:25:42 +01:00
parent 96425d4071
commit d9058a9182
3 changed files with 79 additions and 22 deletions

synapse/handlers/sync.py

@@ -171,6 +171,7 @@ class JoinedSyncResult:
     room_id: str
     timeline: TimelineBatch
     state: StateMap[EventBase]
+    state_after: StateMap[EventBase]
     ephemeral: List[JsonDict]
     account_data: List[JsonDict]
     unread_notifications: JsonDict
@@ -194,6 +195,7 @@ class ArchivedSyncResult:
     room_id: str
     timeline: TimelineBatch
     state: StateMap[EventBase]
+    state_after: StateMap[EventBase]
     account_data: List[JsonDict]

     def __bool__(self) -> bool:
@@ -1141,7 +1143,7 @@ class SyncHandler:
         since_token: Optional[StreamToken],
         end_token: StreamToken,
         full_state: bool,
-    ) -> MutableStateMap[EventBase]:
+    ) -> Tuple[MutableStateMap[EventBase], MutableStateMap[EventBase]]:
         """Works out the difference in state between the end of the previous sync and
         the start of the timeline.
@ -1157,7 +1159,7 @@ class SyncHandler:
`lazy_load_members` still applies when `full_state` is `True`. `lazy_load_members` still applies when `full_state` is `True`.
Returns: Returns:
The state to return in the sync response for the room. The `state` and `state_after` to return in the sync response for the room.
Clients will overlay this onto the state at the end of the previous sync to Clients will overlay this onto the state at the end of the previous sync to
arrive at the state at the start of the timeline. arrive at the state at the start of the timeline.
@@ -1224,11 +1226,15 @@ class SyncHandler:
         # sync's timeline and the start of the current sync's timeline.
         # See the docstring above for details.
         state_ids: StateMap[str]
+        state_after_ids: StateMap[str]
         # We need to know whether the state we fetch may be partial, so check
         # whether the room is partial stated *before* fetching it.
         is_partial_state_room = await self.store.is_partial_state_room(room_id)
         if full_state:
-            state_ids = await self._compute_state_delta_for_full_sync(
+            [
+                state_ids,
+                state_after_ids,
+            ] = await self._compute_state_delta_for_full_sync(
                 room_id,
                 sync_config.user,
                 batch,
@ -1242,7 +1248,10 @@ class SyncHandler:
# is indeed the case. # is indeed the case.
assert since_token is not None assert since_token is not None
state_ids = await self._compute_state_delta_for_incremental_sync( [
state_ids,
state_after_ids,
] = await self._compute_state_delta_for_incremental_sync(
room_id, room_id,
batch, batch,
since_token, since_token,
@ -1258,6 +1267,7 @@ class SyncHandler:
assert members_to_fetch is not None assert members_to_fetch is not None
assert first_event_by_sender_map is not None assert first_event_by_sender_map is not None
# TODO: would this need to take account of state_after_ids?
additional_state_ids = ( additional_state_ids = (
await self._find_missing_partial_state_memberships( await self._find_missing_partial_state_memberships(
room_id, members_to_fetch, first_event_by_sender_map, state_ids room_id, members_to_fetch, first_event_by_sender_map, state_ids
@@ -1304,14 +1314,26 @@ class SyncHandler:
         state: Dict[str, EventBase] = {}
         if state_ids:
             state = await self.store.get_events(list(state_ids.values()))

+        state_after: Dict[str, EventBase] = {}
+        if state_after_ids:
+            state_after = await self.store.get_events(list(state_after_ids.values()))
+
-        return {
+        return [
+            {
             (e.type, e.state_key): e
             for e in await sync_config.filter_collection.filter_room_state(
                 list(state.values())
             )
             if e.type != EventTypes.Aliases  # until MSC2261 or alternative solution
-        }
+            },
+            {
+            (e.type, e.state_key): e
+            for e in await sync_config.filter_collection.filter_room_state(
+                list(state_after.values())
+            )
+            if e.type != EventTypes.Aliases  # until MSC2261 or alternative solution
+            },
+        ]

     async def _compute_state_delta_for_full_sync(
         self,
@@ -1321,7 +1343,7 @@ class SyncHandler:
         end_token: StreamToken,
         members_to_fetch: Optional[Set[str]],
         timeline_state: StateMap[str],
-    ) -> StateMap[str]:
+    ) -> Tuple[StateMap[str], StateMap[str]]:
         """Calculate the state events to be included in a full sync response.

         As with `_compute_state_delta_for_incremental_sync`, the result will include
@@ -1341,7 +1363,7 @@ class SyncHandler:
         Returns:
             A map from (type, state_key) to event_id, for each event that we believe
-            should be included in the `state` part of the sync response.
+            should be included in the `state` and `state_after` parts of the sync response.
         """
         if members_to_fetch is not None:
             # Lazy-loading of membership events is enabled.
@@ -1410,7 +1432,7 @@ class SyncHandler:
         end_token: StreamToken,
         members_to_fetch: Optional[Set[str]],
         timeline_state: StateMap[str],
-    ) -> StateMap[str]:
+    ) -> Tuple[StateMap[str], StateMap[str]]:
         """Calculate the state events to be included in an incremental sync response.

         If lazy-loading of membership events is enabled (as indicated by
@@ -1433,7 +1455,7 @@ class SyncHandler:
         Returns:
             A map from (type, state_key) to event_id, for each event that we believe
-            should be included in the `state` part of the sync response.
+            should be included in the `state` and `state_after` parts of the sync response.
         """
         if members_to_fetch is not None:
             # Lazy-loading is enabled. Only return the state that is needed.
@ -1491,7 +1513,7 @@ class SyncHandler:
await_full_state=False, await_full_state=False,
) )
) )
return state_ids return [state_ids, {}]
if batch: if batch:
state_at_timeline_start = ( state_at_timeline_start = (
@ -2860,7 +2882,7 @@ class SyncHandler:
return return
if not room_builder.out_of_band: if not room_builder.out_of_band:
state = await self.compute_state_delta( [state, state_after] = await self.compute_state_delta(
room_id, room_id,
batch, batch,
sync_config, sync_config,
@ -2871,6 +2893,7 @@ class SyncHandler:
else: else:
# An out of band room won't have any state changes. # An out of band room won't have any state changes.
state = {} state = {}
state_after = {}
summary: Optional[JsonDict] = {} summary: Optional[JsonDict] = {}
@ -2905,6 +2928,7 @@ class SyncHandler:
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=state, state=state,
state_after=state_after,
ephemeral=ephemeral, ephemeral=ephemeral,
account_data=account_data_events, account_data=account_data_events,
unread_notifications=unread_notifications, unread_notifications=unread_notifications,
@@ -2957,6 +2981,7 @@ class SyncHandler:
                 room_id=room_id,
                 timeline=batch,
                 state=state,
+                state_after=state_after,
                 account_data=account_data_events,
             )
             if archived_room_sync or always_include:
@@ -2982,8 +3007,8 @@ def _calculate_state(
     timeline_end: StateMap[str],
     previous_timeline_end: StateMap[str],
     lazy_load_members: bool,
-) -> StateMap[str]:
-    """Works out what state to include in a sync response.
+) -> Tuple[StateMap[str], StateMap[str]]:
+    """Works out what state and state_after to include in a sync response.

     Args:
         timeline_contains: state in the timeline
@@ -3080,13 +3105,18 @@ def _calculate_state(
     # even try; it is either omitted or plonked into `state` as if it were at the start
     # of the timeline, depending on what else is in the timeline.)

-    state_ids = (
+    state_before_ids = (
         (timeline_end_ids | timeline_start_ids)
         - previous_timeline_end_ids
         - timeline_contains_ids
     )

-    return {event_id_to_state_key[e]: e for e in state_ids}
+    state_after_ids = timeline_end_ids - timeline_contains_ids - timeline_start_ids
+
+    return [
+        {event_id_to_state_key[e]: e for e in state_before_ids},
+        {event_id_to_state_key[e]: e for e in state_after_ids},
+    ]


 @attr.s(slots=True, auto_attribs=True)
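
The two set expressions above are the core of the change, and are easiest to check with concrete data. The following is an illustrative sketch only: the event IDs and set values are invented to mirror the S1/S2 scenario in the tests further down, and only the two formulas come from the diff.

# Toy data, not from this commit: a gappy sync whose timeline holds two message
# events (E3, E4), while state resolution replaced S1 with S2 on another branch.
timeline_start_ids = {"$S1"}         # room state at the start of the timeline
timeline_end_ids = {"$S2"}           # room state at the end of the timeline
previous_timeline_end_ids = {"$S1"}  # state the client saw at the previous sync
timeline_contains_ids = set()        # no state events inside the timeline itself

# `state` (overlaid *before* the timeline), per the formula in the diff:
state_before_ids = (
    (timeline_end_ids | timeline_start_ids)
    - previous_timeline_end_ids
    - timeline_contains_ids
)
assert state_before_ids == {"$S2"}

# `state_after` (overlaid *after* the timeline), per the formula in the diff:
state_after_ids = timeline_end_ids - timeline_contains_ids - timeline_start_ids
assert state_after_ids == {"$S2"}  # S2 is repeated: it is current after E4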

synapse/rest/client/sync.py

@@ -521,9 +521,11 @@ class SyncRestServlet(RestServlet):
             The room, encoded in our response format
         """
         state_dict = room.state
+        state_after_dict = room.state_after
         timeline_events = room.timeline.events

         state_events = state_dict.values()
+        state_after_events = state_after_dict.values()

         for event in itertools.chain(state_events, timeline_events):
             # We've had bug reports that events were coming down under the
@ -545,6 +547,9 @@ class SyncRestServlet(RestServlet):
config=serialize_options, config=serialize_options,
bundle_aggregations=room.timeline.bundled_aggregations, bundle_aggregations=room.timeline.bundled_aggregations,
) )
serialized_state_after = await self._event_serializer.serialize_events(
state_after_events, time_now, config=serialize_options
)
account_data = room.account_data account_data = room.account_data
@@ -555,6 +560,7 @@ class SyncRestServlet(RestServlet):
                 "limited": room.timeline.limited,
             },
             "state": {"events": serialized_state},
+            "state_after": {"events": serialized_state_after},
             "account_data": {"events": account_data},
         }
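
Each room object in the /sync response therefore gains a `state_after` block alongside the existing `state` block. A sketch of the resulting shape, with invented placeholder values (only the key layout follows the servlet code above):

room_json = {
    "timeline": {
        "events": [],      # serialized timeline events (placeholder)
        "limited": False,  # placeholder value
    },
    "state": {"events": []},        # state to overlay before the timeline
    "state_after": {"events": []},  # state to overlay after the timeline
    "account_data": {"events": []},
}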

tests/handlers/test_sync.py

@@ -571,6 +571,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             [e.event_id for e in room_sync.state.values()],
             [],
         )
+        self.assertEqual(room_sync.state_after, {})

         # Now send another event that points to S2, but not E3.
         with self._patch_get_latest_events([s2_event]):
@@ -602,6 +603,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             [e.event_id for e in room_sync.state.values()],
             [s2_event],
         )
+        self.assertEqual(room_sync.state_after, {})

     def test_state_includes_changes_on_ungappy_syncs(self) -> None:
         """Test `state` where the sync is not gappy.
@@ -776,6 +778,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             [e.event_id for e in room_sync.timeline.events],
             [unrelated_state_event, s1_event],
         )
+        self.assertEqual(room_sync.state_after, {})

         # Send S2 -> S1
         s2_event = self.helper.send_state(
@@ -800,6 +803,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             [e.event_id for e in room_sync.timeline.events],
             [s2_event],
         )
+        self.assertEqual(room_sync.state_after, {})

         # Send two regular events on different branches:
         # E3 -> S1
@@ -835,6 +839,12 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
                 e4_event,
             ],  # We have two events from different timelines neither of which are state events
         )
+        self.assertEqual(
+            [e.event_id for e in room_sync.state_after.values()],
+            [
+                s2_event
+            ],  # S2 is repeated because it is the state at the end of the timeline (after E4)
+        )

         # Send E5 which resolves the branches
         e5_event = self.helper.send(room_id, "E5", tok=alice_tok)["event_id"]
@@ -857,7 +867,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             [e.event_id for e in room_sync.timeline.events],
             [e5_event],
         )
-        # Problem: S2 is the winning state event but the last state event the client saw was S1.
+        self.assertEqual(room_sync.state_after, {})
+        # FIXED: S2 is the winning state event and the last state event that the client saw!

     def test_state_after_on_branches_winner_at_start_of_timeline(self) -> None:
         r"""Test `state` and `state_after` where not all information is in `state` + `timeline`.
@@ -922,6 +934,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             [e.event_id for e in room_sync.timeline.events],
             [unrelated_state_event, s1_event],
         )
+        self.assertEqual(room_sync.state_after, {})

         # Send S2 -> S1
         s2_event = self.helper.send_state(
@@ -946,6 +959,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             [e.event_id for e in room_sync.timeline.events],
             [s2_event],
         )
+        self.assertEqual(room_sync.state_after, {})

         # Send two events on different branches:
         # S3 -> S1
@@ -978,6 +992,12 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
                 e4_event,
             ],  # We have two events from different timelines
         )
+        self.assertEqual(
+            [e.event_id for e in room_sync.state_after.values()],
+            [
+                s2_event
+            ],  # S2 is repeated because it is the state at the end of the timeline (after E4)
+        )

         # Send E5 which resolves the branches with S3 winning
         e5_event = self.helper.send(room_id, "E5", tok=alice_tok)["event_id"]
@@ -1003,6 +1023,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             [e.event_id for e in room_sync.timeline.events],
             [e5_event],
         )
+        self.assertEqual(room_sync.state_after, {})

     @parameterized.expand(
         [
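
Taken together, the tests pin down the intended client-side layering: previous state, then the `state` block, then state events from the timeline, then `state_after`, with later layers winning. A minimal hypothetical sketch of that overlay; none of these helper names come from the commit:

from typing import Dict, Tuple

StateKey = Tuple[str, str]  # (event type, state key)

def resolve_client_state(
    previous: Dict[StateKey, str],
    state_before: Dict[StateKey, str],
    timeline_state: Dict[StateKey, str],
    state_after: Dict[StateKey, str],
) -> Dict[StateKey, str]:
    # Later layers win, so `state_after` can deliver a state event (like S2
    # above) that won state resolution on a branch the timeline never shows.
    return {**previous, **state_before, **timeline_state, **state_after}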