state_after WIP

Hugh Nimmo-Smith 2024-10-22 15:25:42 +01:00
parent 96425d4071
commit d9058a9182
3 changed files with 79 additions and 22 deletions

@@ -171,6 +171,7 @@ class JoinedSyncResult:
room_id: str
timeline: TimelineBatch
state: StateMap[EventBase]
state_after: StateMap[EventBase]
ephemeral: List[JsonDict]
account_data: List[JsonDict]
unread_notifications: JsonDict
@@ -194,6 +195,7 @@ class ArchivedSyncResult:
room_id: str
timeline: TimelineBatch
state: StateMap[EventBase]
state_after: StateMap[EventBase]
account_data: List[JsonDict]
def __bool__(self) -> bool:
@@ -1141,7 +1143,7 @@ class SyncHandler:
since_token: Optional[StreamToken],
end_token: StreamToken,
full_state: bool,
) -> MutableStateMap[EventBase]:
) -> Tuple[MutableStateMap[EventBase], MutableStateMap[EventBase]]:
"""Works out the difference in state between the end of the previous sync and
the start of the timeline.
@@ -1157,7 +1159,7 @@ class SyncHandler:
`lazy_load_members` still applies when `full_state` is `True`.
Returns:
The state to return in the sync response for the room.
The `state` and `state_after` to return in the sync response for the room.
Clients will overlay this onto the state at the end of the previous sync to
arrive at the state at the start of the timeline.
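A rough sketch, not part of this change, of how a client might be expected to combine the two maps; the semantics are an interpretation of the WIP, and apply_room_sync / prev_state are hypothetical names:

def apply_room_sync(prev_state, room):
    # State at the start of the timeline: the previous state overlaid with `state`.
    state_at_start = dict(prev_state)
    state_at_start.update({(e.type, e.state_key): e for e in room.state.values()})

    # Roll forward through any state events that appear in the timeline...
    state_at_end = dict(state_at_start)
    for event in room.timeline.events:
        if event.is_state():
            state_at_end[(event.type, event.state_key)] = event

    # ...then overlay `state_after`, which carries state that holds at the end of
    # the timeline but was never sent down in `state` or in the timeline itself.
    state_at_end.update({(e.type, e.state_key): e for e in room.state_after.values()})
    return state_at_start, state_at_end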
@@ -1224,11 +1226,15 @@ class SyncHandler:
# sync's timeline and the start of the current sync's timeline.
# See the docstring above for details.
state_ids: StateMap[str]
state_after_ids: StateMap[str]
# We need to know whether the state we fetch may be partial, so check
# whether the room is partial stated *before* fetching it.
is_partial_state_room = await self.store.is_partial_state_room(room_id)
if full_state:
state_ids = await self._compute_state_delta_for_full_sync(
(
state_ids,
state_after_ids,
) = await self._compute_state_delta_for_full_sync(
room_id,
sync_config.user,
batch,
@@ -1242,7 +1248,10 @@ class SyncHandler:
# is indeed the case.
assert since_token is not None
state_ids = await self._compute_state_delta_for_incremental_sync(
(
state_ids,
state_after_ids,
) = await self._compute_state_delta_for_incremental_sync(
room_id,
batch,
since_token,
@@ -1258,6 +1267,7 @@ class SyncHandler:
assert members_to_fetch is not None
assert first_event_by_sender_map is not None
# TODO: would this need to take account of state_after_ids?
additional_state_ids = (
await self._find_missing_partial_state_memberships(
room_id, members_to_fetch, first_event_by_sender_map, state_ids
@@ -1304,14 +1314,26 @@ class SyncHandler:
state: Dict[str, EventBase] = {}
if state_ids:
state = await self.store.get_events(list(state_ids.values()))
state_after: Dict[str, EventBase] = {}
if state_after_ids:
state_after = await self.store.get_events(list(state_after_ids.values()))
return {
return (
{
(e.type, e.state_key): e
for e in await sync_config.filter_collection.filter_room_state(
list(state.values())
)
if e.type != EventTypes.Aliases # until MSC2261 or alternative solution
}
},
{
(e.type, e.state_key): e
for e in await sync_config.filter_collection.filter_room_state(
list(state_after.values())
)
if e.type != EventTypes.Aliases # until MSC2261 or alternative solution
},
)
async def _compute_state_delta_for_full_sync(
self,
@@ -1321,7 +1343,7 @@ class SyncHandler:
end_token: StreamToken,
members_to_fetch: Optional[Set[str]],
timeline_state: StateMap[str],
) -> StateMap[str]:
) -> Tuple[StateMap[str], StateMap[str]]:
"""Calculate the state events to be included in a full sync response.
As with `_compute_state_delta_for_incremental_sync`, the result will include
@@ -1341,7 +1363,7 @@ class SyncHandler:
Returns:
A map from (type, state_key) to event_id, for each event that we believe
should be included in the `state` part of the sync response.
should be included in the `state` and `state_after` parts of the sync response.
"""
if members_to_fetch is not None:
# Lazy-loading of membership events is enabled.
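For shape only, a hedged illustration of the pair this helper now returns; the event IDs are invented, and whether a full sync ever populates the second map is not visible in this hunk:

# Both elements map (event type, state key) -> event_id.
state_ids = {
    ("m.room.create", ""): "$create_event_id",
    ("m.room.member", "@alice:example.org"): "$alice_join_event_id",
}
state_after_ids = {}  # any state that additionally holds after the end of the timeline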
@@ -1410,7 +1432,7 @@ class SyncHandler:
end_token: StreamToken,
members_to_fetch: Optional[Set[str]],
timeline_state: StateMap[str],
) -> StateMap[str]:
) -> Tuple[StateMap[str], StateMap[str]]:
"""Calculate the state events to be included in an incremental sync response.
If lazy-loading of membership events is enabled (as indicated by
@@ -1433,7 +1455,7 @@ class SyncHandler:
Returns:
A map from (type, state_key) to event_id, for each event that we believe
should be included in the `state` part of the sync response.
should be included in the `state` and `state_after` parts of the sync response.
"""
if members_to_fetch is not None:
# Lazy-loading is enabled. Only return the state that is needed.
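As a reminder of what lazy-loading restricts membership state to, a simplified self-contained sketch with hypothetical user IDs (the real code expresses this through state filters):

# With lazy-loading, membership state is trimmed to senders that actually appear
# in the returned timeline; other state types are unaffected.
timeline_senders = {"@alice:example.org", "@bob:example.org"}  # roughly members_to_fetch
needed_membership_keys = {("m.room.member", user_id) for user_id in timeline_senders}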
@@ -1491,7 +1513,7 @@ class SyncHandler:
await_full_state=False,
)
)
return state_ids
return state_ids, {}
if batch:
state_at_timeline_start = (
@@ -2860,7 +2882,7 @@ class SyncHandler:
return
if not room_builder.out_of_band:
state = await self.compute_state_delta(
state, state_after = await self.compute_state_delta(
room_id,
batch,
sync_config,
@@ -2871,6 +2893,7 @@ class SyncHandler:
else:
# An out of band room won't have any state changes.
state = {}
state_after = {}
summary: Optional[JsonDict] = {}
@@ -2905,6 +2928,7 @@ class SyncHandler:
room_id=room_id,
timeline=batch,
state=state,
state_after=state_after,
ephemeral=ephemeral,
account_data=account_data_events,
unread_notifications=unread_notifications,
@@ -2957,6 +2981,7 @@ class SyncHandler:
room_id=room_id,
timeline=batch,
state=state,
state_after=state_after,
account_data=account_data_events,
)
if archived_room_sync or always_include:
@@ -2982,8 +3007,8 @@ def _calculate_state(
timeline_end: StateMap[str],
previous_timeline_end: StateMap[str],
lazy_load_members: bool,
) -> StateMap[str]:
"""Works out what state to include in a sync response.
) -> Tuple[StateMap[str], StateMap[str]]:
"""Works out what state and state_after to include in a sync response.
Args:
timeline_contains: state in the timeline
@@ -3080,13 +3105,18 @@ def _calculate_state(
# even try; it is either omitted or plonked into `state` as if it were at the start
# of the timeline, depending on what else is in the timeline.)
state_ids = (
state_before_ids = (
(timeline_end_ids | timeline_start_ids)
- previous_timeline_end_ids
- timeline_contains_ids
)
return {event_id_to_state_key[e]: e for e in state_ids}
state_after_ids = timeline_end_ids - timeline_contains_ids - timeline_start_ids
return (
{event_id_to_state_key[e]: e for e in state_before_ids},
{event_id_to_state_key[e]: e for e in state_after_ids},
)
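A minimal worked example of the set arithmetic above, with invented event IDs; the scenario loosely mirrors the forked-timeline tests further down:

# A gappy sync whose timeline follows a branch where $N1 is the current room name,
# but state resolution at the end of the timeline picks $N2 from another branch.
timeline_contains_ids = set()        # no state events in the timeline itself
timeline_start_ids = {"$N1"}         # state at the first event of the timeline
timeline_end_ids = {"$N2"}           # resolved state at the last event
previous_timeline_end_ids = {"$N1"}  # state at the end of the previous sync

state_before_ids = (
    (timeline_end_ids | timeline_start_ids)
    - previous_timeline_end_ids
    - timeline_contains_ids
)
state_after_ids = timeline_end_ids - timeline_contains_ids - timeline_start_ids
assert state_before_ids == state_after_ids == {"$N2"}
# $N2 shows up in both maps: `state` delivers the event itself, while `state_after`
# records that it still wins at the end of the timeline.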
@attr.s(slots=True, auto_attribs=True)

@@ -521,9 +521,11 @@ class SyncRestServlet(RestServlet):
The room, encoded in our response format
"""
state_dict = room.state
state_after_dict = room.state_after
timeline_events = room.timeline.events
state_events = state_dict.values()
state_after_events = state_after_dict.values()
for event in itertools.chain(state_events, timeline_events):
# We've had bug reports that events were coming down under the
@@ -545,6 +547,9 @@ class SyncRestServlet(RestServlet):
config=serialize_options,
bundle_aggregations=room.timeline.bundled_aggregations,
)
serialized_state_after = await self._event_serializer.serialize_events(
state_after_events, time_now, config=serialize_options
)
account_data = room.account_data
@@ -555,6 +560,7 @@ class SyncRestServlet(RestServlet):
"limited": room.timeline.limited,
},
"state": {"events": serialized_state},
"state_after": {"events": serialized_state_after},
"account_data": {"events": account_data},
}
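Illustration only of the resulting per-room JSON shape; event bodies are elided and only the keys visible in this hunk are shown:

joined_room_response = {
    "timeline": {"events": ["..."], "limited": False},
    "state": {"events": ["..."]},        # delta up to the start of the timeline
    "state_after": {"events": ["..."]},  # state additionally holding after the timeline
    "account_data": {"events": []},
}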

@@ -571,6 +571,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.state.values()],
[],
)
self.assertEqual(room_sync.state_after, {})
# Now send another event that points to S2, but not E3.
with self._patch_get_latest_events([s2_event]):
@@ -602,6 +603,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.state.values()],
[s2_event],
)
self.assertEqual(room_sync.state_after, {})
def test_state_includes_changes_on_ungappy_syncs(self) -> None:
"""Test `state` where the sync is not gappy.
@@ -776,6 +778,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[unrelated_state_event, s1_event],
)
self.assertEqual(room_sync.state_after, {})
# Send S2 -> S1
s2_event = self.helper.send_state(
@@ -800,6 +803,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[s2_event],
)
self.assertEqual(room_sync.state_after, {})
# Send two regular events on different branches:
# E3 -> S1
@@ -835,6 +839,12 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
e4_event,
], # We have two events from different timelines, neither of which is a state event
)
self.assertEqual(
[e.event_id for e in room_sync.state_after.values()],
[
s2_event
], # S2 is repeated because it is the state at the end of the timeline (after E4)
)
# Send E5 which resolves the branches
e5_event = self.helper.send(room_id, "E5", tok=alice_tok)["event_id"]
@@ -857,7 +867,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[e5_event],
)
# Problem: S2 is the winning state event but the last state event the client saw was S1.
self.assertEqual(room_sync.state_after, {})
# FIXED: S2 is the winning state event and the last state event that the client saw!
def test_state_after_on_branches_winner_at_start_of_timeline(self) -> None:
r"""Test `state` and `state_after` where not all information is in `state` + `timeline`.
@@ -922,6 +934,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[unrelated_state_event, s1_event],
)
self.assertEqual(room_sync.state_after, {})
# Send S2 -> S1
s2_event = self.helper.send_state(
@@ -946,6 +959,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[s2_event],
)
self.assertEqual(room_sync.state_after, {})
# Send two events on different branches:
# S3 -> S1
@@ -978,6 +992,12 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
e4_event,
], # We have two events from different timelines
)
self.assertEqual(
[e.event_id for e in room_sync.state_after.values()],
[
s2_event
], # S2 is repeated because it is the state at the end of the timeline (after E4)
)
# Send E5 which resolves the branches with S3 winning
e5_event = self.helper.send(room_id, "E5", tok=alice_tok)["event_id"]
@@ -1003,6 +1023,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[e5_event],
)
self.assertEqual(room_sync.state_after, {})
@parameterized.expand(
[