diff --git a/tests/test_state.py b/tests/test_state.py index df9362c985..de2d35145a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -67,6 +67,8 @@ class StateGroupStore(object): self._event_to_state_group = {} self._group_to_state = {} + self._event_id_to_event = {} + self._next_group = 1 def get_state_groups_ids(self, room_id, event_ids): @@ -96,6 +98,16 @@ class StateGroupStore(object): self._event_to_state_group[event.event_id] = state_group + def get_events(self, event_ids, **kwargs): + return { + e_id: self._event_id_to_event[e_id] for e_id in event_ids + if e_id in self._event_id_to_event + } + + def register_events(self, events): + for e in events: + self._event_id_to_event[e.event_id] = e + class DictObj(dict): def __init__(self, **kwargs): @@ -138,6 +150,7 @@ class StateTestCase(unittest.TestCase): spec_set=[ "get_state_groups_ids", "add_event_hashes", + "get_events", ] ) hs = Mock(spec_set=[ @@ -240,6 +253,8 @@ class StateTestCase(unittest.TestCase): store = StateGroupStore() self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids + self.store.get_events = store.get_events + store.register_events(graph.walk()) context_store = {} @@ -250,7 +265,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"START", "A", "C"}, - {e.event_id for e in context_store["D"].current_state.values()} + {e_id for e_id in context_store["D"].current_state_ids.values()} ) @defer.inlineCallbacks @@ -304,6 +319,8 @@ class StateTestCase(unittest.TestCase): store = StateGroupStore() self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids + self.store.get_events = store.get_events + store.register_events(graph.walk()) context_store = {} @@ -314,7 +331,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"START", "A", "B", "C"}, - {e.event_id for e in context_store["E"].current_state.values()} + {e for e in context_store["E"].current_state_ids.values()} ) @defer.inlineCallbacks @@ -385,6 +402,8 @@ class StateTestCase(unittest.TestCase): store = StateGroupStore() self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids + self.store.get_events = store.get_events + store.register_events(graph.walk()) context_store = {} @@ -395,7 +414,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"A1", "A2", "A3", "A5", "B"}, - {e.event_id for e in context_store["D"].current_state.values()} + {e for e in context_store["D"].current_state_ids.values()} ) def _add_depths(self, nodes, edges): @@ -522,6 +541,11 @@ class StateTestCase(unittest.TestCase): create_event(type="test4", state_key=""), ] + store = StateGroupStore() + store.register_events(old_state_1) + store.register_events(old_state_2) + self.store.get_events = store.get_events + context = yield self._get_context(event, old_state_1, old_state_2) self.assertEqual(len(context.current_state_ids), 6) @@ -550,6 +574,11 @@ class StateTestCase(unittest.TestCase): create_event(type="test4", state_key=""), ] + store = StateGroupStore() + store.register_events(old_state_1) + store.register_events(old_state_2) + self.store.get_events = store.get_events + context = yield self._get_context(event, old_state_1, old_state_2) self.assertEqual(len(context.current_state_ids), 6) @@ -585,9 +614,16 @@ class StateTestCase(unittest.TestCase): create_event(type="test1", state_key="1", depth=2), ] + store = StateGroupStore() + store.register_events(old_state_1) + store.register_events(old_state_2) + self.store.get_events = store.get_events + context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(old_state_2[2].event.id, context.current_state_ids[("test1", "1")]) + self.assertEqual( + old_state_2[2].event_id, context.current_state_ids[("test1", "1")] + ) # Reverse the depth to make sure we are actually using the depths # during state resolution. @@ -604,9 +640,14 @@ class StateTestCase(unittest.TestCase): create_event(type="test1", state_key="1", depth=1), ] + store.register_events(old_state_1) + store.register_events(old_state_2) + context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(old_state_1[2].event_id, context.current_state_ids[("test1", "1")]) + self.assertEqual( + old_state_1[2].event_id, context.current_state_ids[("test1", "1")] + ) def _get_context(self, event, old_state_1, old_state_2): group_name_1 = "group_name_1"