diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 408d375439..e224af8dd8 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -124,7 +124,7 @@ async def resolve_events_with_store( # Also fetch all auth events that appear in only some of the state sets' # auth chains. auth_diff = await _get_auth_chain_difference( - room_id, state_sets, event_map, state_res_store, clock + room_id, state_sets, event_map, state_res_store ) with Measure(clock, "rei_state_res:rews2_b"): # TODO temporary (rei) @@ -284,7 +284,6 @@ async def _get_auth_chain_difference( state_sets: Sequence[StateMap[str]], unpersisted_events: Dict[str, EventBase], state_res_store: StateResolutionStore, - clock: Clock, ) -> Set[str]: """Compare the auth chains of each state set and return the set of events that only appear in some, but not all of the auth chains. @@ -316,82 +315,77 @@ async def _get_auth_chain_difference( # event IDs if they appear in the `unpersisted_events`. This is the intersection of # the event's auth chain with the events in `unpersisted_events` *plus* their # auth event IDs. - with Measure(clock, "rei_state_res:rews2_a1"): # TODO temporary (rei) - events_to_auth_chain: Dict[str, Set[str]] = {} - for event in unpersisted_events.values(): - chain = {event.event_id} - events_to_auth_chain[event.event_id] = chain + events_to_auth_chain: Dict[str, Set[str]] = {} + for event in unpersisted_events.values(): + chain = {event.event_id} + events_to_auth_chain[event.event_id] = chain - to_search = [event] - while to_search: - for auth_id in to_search.pop().auth_event_ids(): - chain.add(auth_id) - auth_event = unpersisted_events.get(auth_id) - if auth_event: - to_search.append(auth_event) + to_search = [event] + while to_search: + for auth_id in to_search.pop().auth_event_ids(): + chain.add(auth_id) + auth_event = unpersisted_events.get(auth_id) + if auth_event: + to_search.append(auth_event) # We now 1) calculate the auth chain difference for the unpersisted events # and 2) work out the state sets to pass to the store. # # Note: If there are no `unpersisted_events` (which is the common case), we can do a # much simpler calculation. - with Measure(clock, "rei_state_res:rews2_a2"): # TODO temporary (rei) - if unpersisted_events: - # The list of state sets to pass to the store, where each state set is a set - # of the event ids making up the state. This is similar to `state_sets`, - # except that (a) we only have event ids, not the complete - # ((type, state_key)->event_id) mappings; and (b) we have stripped out - # unpersisted events and replaced them with the persisted events in - # their auth chain. - state_sets_ids: List[Set[str]] = [] + if unpersisted_events: + # The list of state sets to pass to the store, where each state set is a set + # of the event ids making up the state. This is similar to `state_sets`, + # except that (a) we only have event ids, not the complete + # ((type, state_key)->event_id) mappings; and (b) we have stripped out + # unpersisted events and replaced them with the persisted events in + # their auth chain. + state_sets_ids: List[Set[str]] = [] - # For each state set, the unpersisted event IDs reachable (by their auth - # chain) from the events in that set. - unpersisted_set_ids: List[Set[str]] = [] + # For each state set, the unpersisted event IDs reachable (by their auth + # chain) from the events in that set. + unpersisted_set_ids: List[Set[str]] = [] - for state_set in state_sets: - set_ids: Set[str] = set() - state_sets_ids.append(set_ids) + for state_set in state_sets: + set_ids: Set[str] = set() + state_sets_ids.append(set_ids) - unpersisted_ids: Set[str] = set() - unpersisted_set_ids.append(unpersisted_ids) + unpersisted_ids: Set[str] = set() + unpersisted_set_ids.append(unpersisted_ids) - for event_id in state_set.values(): - event_chain = events_to_auth_chain.get(event_id) - if event_chain is not None: - # We have an unpersisted event. We add all the auth - # events that it references which are also unpersisted. - set_ids.update( - e for e in event_chain if e not in unpersisted_events - ) + for event_id in state_set.values(): + event_chain = events_to_auth_chain.get(event_id) + if event_chain is not None: + # We have an unpersisted event. We add all the auth + # events that it references which are also unpersisted. + set_ids.update( + e for e in event_chain if e not in unpersisted_events + ) - # We also add the full chain of unpersisted event IDs - # referenced by this state set, so that we can work out the - # auth chain difference of the unpersisted events. - unpersisted_ids.update( - e for e in event_chain if e in unpersisted_events - ) - else: - set_ids.add(event_id) + # We also add the full chain of unpersisted event IDs + # referenced by this state set, so that we can work out the + # auth chain difference of the unpersisted events. + unpersisted_ids.update( + e for e in event_chain if e in unpersisted_events + ) + else: + set_ids.add(event_id) - # The auth chain difference of the unpersisted events of the state sets - # is calculated by taking the difference between the union and - # intersections. - union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) - intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) + # The auth chain difference of the unpersisted events of the state sets + # is calculated by taking the difference between the union and + # intersections. + union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) + intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) - auth_difference_unpersisted_part: StrCollection = union - intersection - else: - auth_difference_unpersisted_part = () - state_sets_ids = [set(state_set.values()) for state_set in state_sets] + auth_difference_unpersisted_part: StrCollection = union - intersection + else: + auth_difference_unpersisted_part = () + state_sets_ids = [set(state_set.values()) for state_set in state_sets] - with Measure(clock, "rei_state_res:rews2_a3"): # TODO temporary (rei) - difference = await state_res_store.get_auth_chain_difference( - room_id, state_sets_ids - ) - - with Measure(clock, "rei_state_res:rews2_a4"): # TODO temporary (rei) - difference.update(auth_difference_unpersisted_part) + difference = await state_res_store.get_auth_chain_difference( + room_id, state_sets_ids + ) + difference.update(auth_difference_unpersisted_part) return difference