Pass room_id to get_auth_chain_difference (#8879)

This is so that we can choose which algorithm to use based on the room ID.
This commit is contained in:
Erik Johnston 2020-12-04 15:52:49 +00:00 committed by GitHub
parent b774c555d8
commit df4b1e9c74
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 33 additions and 17 deletions

1
changelog.d/8879.misc Normal file
View file

@ -0,0 +1 @@
Pass `room_id` to `get_auth_chain_difference`.

View file

@ -783,7 +783,7 @@ class StateResolutionStore:
) )
def get_auth_chain_difference( def get_auth_chain_difference(
self, state_sets: List[Set[str]] self, room_id: str, state_sets: List[Set[str]]
) -> Awaitable[Set[str]]: ) -> Awaitable[Set[str]]:
"""Given sets of state events figure out the auth chain difference (as """Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm). per state res v2 algorithm).
@ -796,4 +796,4 @@ class StateResolutionStore:
An awaitable that resolves to a set of event IDs. An awaitable that resolves to a set of event IDs.
""" """
return self.store.get_auth_chain_difference(state_sets) return self.store.get_auth_chain_difference(room_id, state_sets)

View file

@ -97,7 +97,9 @@ async def resolve_events_with_store(
# Also fetch all auth events that appear in only some of the state sets' # Also fetch all auth events that appear in only some of the state sets'
# auth chains. # auth chains.
auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store) auth_diff = await _get_auth_chain_difference(
room_id, state_sets, event_map, state_res_store
)
full_conflicted_set = set( full_conflicted_set = set(
itertools.chain( itertools.chain(
@ -236,6 +238,7 @@ async def _get_power_level_for_sender(
async def _get_auth_chain_difference( async def _get_auth_chain_difference(
room_id: str,
state_sets: Sequence[StateMap[str]], state_sets: Sequence[StateMap[str]],
event_map: Dict[str, EventBase], event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore", state_res_store: "synapse.state.StateResolutionStore",
@ -332,7 +335,9 @@ async def _get_auth_chain_difference(
difference_from_event_map = () difference_from_event_map = ()
state_sets_ids = [set(state_set.values()) for state_set in state_sets] state_sets_ids = [set(state_set.values()) for state_set in state_sets]
difference = await state_res_store.get_auth_chain_difference(state_sets_ids) difference = await state_res_store.get_auth_chain_difference(
room_id, state_sets_ids
)
difference.update(difference_from_event_map) difference.update(difference_from_event_map)
return difference return difference

View file

@ -137,7 +137,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results) return list(results)
async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]: async def get_auth_chain_difference(
self, room_id: str, state_sets: List[Set[str]]
) -> Set[str]:
"""Given sets of state events figure out the auth chain difference (as """Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm). per state res v2 algorithm).

View file

@ -623,7 +623,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
store = TestStateResolutionStore(persisted_events) store = TestStateResolutionStore(persisted_events)
diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store) diff_d = _get_auth_chain_difference(
ROOM_ID, state_sets, unpersited_events, store
)
difference = self.successResultOf(defer.ensureDeferred(diff_d)) difference = self.successResultOf(defer.ensureDeferred(diff_d))
self.assertEqual(difference, {c.event_id}) self.assertEqual(difference, {c.event_id})
@ -662,7 +664,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
store = TestStateResolutionStore(persisted_events) store = TestStateResolutionStore(persisted_events)
diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store) diff_d = _get_auth_chain_difference(
ROOM_ID, state_sets, unpersited_events, store
)
difference = self.successResultOf(defer.ensureDeferred(diff_d)) difference = self.successResultOf(defer.ensureDeferred(diff_d))
self.assertEqual(difference, {d.event_id, c.event_id}) self.assertEqual(difference, {d.event_id, c.event_id})
@ -707,7 +711,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
store = TestStateResolutionStore(persisted_events) store = TestStateResolutionStore(persisted_events)
diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store) diff_d = _get_auth_chain_difference(
ROOM_ID, state_sets, unpersited_events, store
)
difference = self.successResultOf(defer.ensureDeferred(diff_d)) difference = self.successResultOf(defer.ensureDeferred(diff_d))
self.assertEqual(difference, {d.event_id, e.event_id}) self.assertEqual(difference, {d.event_id, e.event_id})
@ -773,7 +779,7 @@ class TestStateResolutionStore:
return list(result) return list(result)
def get_auth_chain_difference(self, auth_sets): def get_auth_chain_difference(self, room_id, auth_sets):
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:]) common = set(chains[0]).intersection(*chains[1:])

View file

@ -202,39 +202,41 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# Now actually test that various combinations give the right result: # Now actually test that various combinations give the right result:
difference = self.get_success( difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}]) self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
) )
self.assertSetEqual(difference, {"a", "b"}) self.assertSetEqual(difference, {"a", "b"})
difference = self.get_success( difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}]) self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
) )
self.assertSetEqual(difference, {"a", "b", "c", "e", "f"}) self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
difference = self.get_success( difference = self.get_success(
self.store.get_auth_chain_difference([{"a", "c"}, {"b"}]) self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
) )
self.assertSetEqual(difference, {"a", "b", "c"}) self.assertSetEqual(difference, {"a", "b", "c"})
difference = self.get_success( difference = self.get_success(
self.store.get_auth_chain_difference([{"a", "c"}, {"b", "c"}]) self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
) )
self.assertSetEqual(difference, {"a", "b"}) self.assertSetEqual(difference, {"a", "b"})
difference = self.get_success( difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}]) self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
) )
self.assertSetEqual(difference, {"a", "b", "d", "e"}) self.assertSetEqual(difference, {"a", "b", "d", "e"})
difference = self.get_success( difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}]) self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
) )
self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"}) self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
difference = self.get_success( difference = self.get_success(
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}]) self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
) )
self.assertSetEqual(difference, {"a", "b"}) self.assertSetEqual(difference, {"a", "b"})
difference = self.get_success(self.store.get_auth_chain_difference([{"a"}])) difference = self.get_success(
self.store.get_auth_chain_difference(room_id, [{"a"}])
)
self.assertSetEqual(difference, set()) self.assertSetEqual(difference, set())