Merge pull request #706 from matrix-org/markjh/slaveIV

Add tests for redactions
This commit is contained in:
Mark Haines 2016-04-07 17:30:37 +01:00
commit da84fa3d74
4 changed files with 54 additions and 5 deletions

View file

@ -69,6 +69,7 @@ class SlavedEventStore(BaseSlavedStore):
"_get_current_state_for_key"
]
get_event = DataStore.get_event.__func__
get_current_state = DataStore.get_current_state.__func__
get_current_state_for_key = DataStore.get_current_state_for_key.__func__
get_rooms_for_user_where_membership_is = (
@ -103,7 +104,7 @@ class SlavedEventStore(BaseSlavedStore):
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
result["backfilled"] = self._backfill_id_gen.get_current_token()
result["backfill"] = self._backfill_id_gen.get_current_token()
return result
def process_replication(self, result):
@ -145,7 +146,6 @@ class SlavedEventStore(BaseSlavedStore):
position = row[0]
internal = json.loads(row[1])
event_json = json.loads(row[2])
event = FrozenEvent(event_json, internal_metadata_dict=internal)
self._invalidate_caches_for_event(
event, backfilled, reset_state=position in state_resets

View file

@ -112,7 +112,7 @@ class StreamIdGenerator(object):
self._current + self._step * (n + 1),
self._step
)
self._current += n
self._current += n * self._step
for next_id in next_ids:
self._unfinished_ids.append(next_id)

View file

@ -51,7 +51,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
def check(self, method, args, expected_result=None):
master_result = yield getattr(self.master_store, method)(*args)
slaved_result = yield getattr(self.slaved_store, method)(*args)
self.assertEqual(master_result, slaved_result)
if expected_result is not None:
self.assertEqual(master_result, expected_result)
self.assertEqual(slaved_result, expected_result)
self.assertEqual(master_result, slaved_result)

View file

@ -205,13 +205,59 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
[join3]
)
@defer.inlineCallbacks
def test_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
msg = yield self.persist(
type="m.room.message", msgtype="m.text", body="Hello"
)
yield self.replicate()
yield self.check("get_event", [msg.event_id], msg)
redaction = yield self.persist(
type="m.room.redaction", redacts=msg.event_id
)
yield self.replicate()
msg_dict = msg.get_dict()
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
yield self.check("get_event", [msg.event_id], redacted)
@defer.inlineCallbacks
def test_backfilled_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
msg = yield self.persist(
type="m.room.message", msgtype="m.text", body="Hello"
)
yield self.replicate()
yield self.check("get_event", [msg.event_id], msg)
redaction = yield self.persist(
type="m.room.redaction", redacts=msg.event_id, backfill=True
)
yield self.replicate()
msg_dict = msg.get_dict()
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
yield self.check("get_event", [msg.event_id], redacted)
event_id = 0
@defer.inlineCallbacks
def persist(
self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={},
state=None, reset_state=False, backfill=False,
depth=None, prev_events=[], auth_events=[], prev_state=[],
depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None,
**content
):
"""
@ -236,6 +282,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event_dict["state_key"] = key
event_dict["prev_state"] = prev_state
if redacts is not None:
event_dict["redacts"] = redacts
event = FrozenEvent(event_dict, internal_metadata_dict=internal)
self.event_id += 1