Fix tests

This commit is contained in:
Erik Johnston 2016-08-26 10:15:52 +01:00
parent 50943ab942
commit 3f11953fcb

View file

@ -67,6 +67,8 @@ class StateGroupStore(object):
self._event_to_state_group = {} self._event_to_state_group = {}
self._group_to_state = {} self._group_to_state = {}
self._event_id_to_event = {}
self._next_group = 1 self._next_group = 1
def get_state_groups_ids(self, room_id, event_ids): def get_state_groups_ids(self, room_id, event_ids):
@ -96,6 +98,16 @@ class StateGroupStore(object):
self._event_to_state_group[event.event_id] = state_group self._event_to_state_group[event.event_id] = state_group
def get_events(self, event_ids, **kwargs):
return {
e_id: self._event_id_to_event[e_id] for e_id in event_ids
if e_id in self._event_id_to_event
}
def register_events(self, events):
for e in events:
self._event_id_to_event[e.event_id] = e
class DictObj(dict): class DictObj(dict):
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -138,6 +150,7 @@ class StateTestCase(unittest.TestCase):
spec_set=[ spec_set=[
"get_state_groups_ids", "get_state_groups_ids",
"add_event_hashes", "add_event_hashes",
"get_events",
] ]
) )
hs = Mock(spec_set=[ hs = Mock(spec_set=[
@ -240,6 +253,8 @@ class StateTestCase(unittest.TestCase):
store = StateGroupStore() store = StateGroupStore()
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
self.store.get_events = store.get_events
store.register_events(graph.walk())
context_store = {} context_store = {}
@ -250,7 +265,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual( self.assertSetEqual(
{"START", "A", "C"}, {"START", "A", "C"},
{e.event_id for e in context_store["D"].current_state.values()} {e_id for e_id in context_store["D"].current_state_ids.values()}
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -304,6 +319,8 @@ class StateTestCase(unittest.TestCase):
store = StateGroupStore() store = StateGroupStore()
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
self.store.get_events = store.get_events
store.register_events(graph.walk())
context_store = {} context_store = {}
@ -314,7 +331,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual( self.assertSetEqual(
{"START", "A", "B", "C"}, {"START", "A", "B", "C"},
{e.event_id for e in context_store["E"].current_state.values()} {e for e in context_store["E"].current_state_ids.values()}
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -385,6 +402,8 @@ class StateTestCase(unittest.TestCase):
store = StateGroupStore() store = StateGroupStore()
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
self.store.get_events = store.get_events
store.register_events(graph.walk())
context_store = {} context_store = {}
@ -395,7 +414,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual( self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"}, {"A1", "A2", "A3", "A5", "B"},
{e.event_id for e in context_store["D"].current_state.values()} {e for e in context_store["D"].current_state_ids.values()}
) )
def _add_depths(self, nodes, edges): def _add_depths(self, nodes, edges):
@ -522,6 +541,11 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""), create_event(type="test4", state_key=""),
] ]
store = StateGroupStore()
store.register_events(old_state_1)
store.register_events(old_state_2)
self.store.get_events = store.get_events
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(len(context.current_state_ids), 6) self.assertEqual(len(context.current_state_ids), 6)
@ -550,6 +574,11 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""), create_event(type="test4", state_key=""),
] ]
store = StateGroupStore()
store.register_events(old_state_1)
store.register_events(old_state_2)
self.store.get_events = store.get_events
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(len(context.current_state_ids), 6) self.assertEqual(len(context.current_state_ids), 6)
@ -585,9 +614,16 @@ class StateTestCase(unittest.TestCase):
create_event(type="test1", state_key="1", depth=2), create_event(type="test1", state_key="1", depth=2),
] ]
store = StateGroupStore()
store.register_events(old_state_1)
store.register_events(old_state_2)
self.store.get_events = store.get_events
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(old_state_2[2].event.id, context.current_state_ids[("test1", "1")]) self.assertEqual(
old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
)
# Reverse the depth to make sure we are actually using the depths # Reverse the depth to make sure we are actually using the depths
# during state resolution. # during state resolution.
@ -604,9 +640,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test1", state_key="1", depth=1), create_event(type="test1", state_key="1", depth=1),
] ]
store.register_events(old_state_1)
store.register_events(old_state_2)
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(old_state_1[2].event_id, context.current_state_ids[("test1", "1")]) self.assertEqual(
old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
)
def _get_context(self, event, old_state_1, old_state_2): def _get_context(self, event, old_state_1, old_state_2):
group_name_1 = "group_name_1" group_name_1 = "group_name_1"