diff --git a/synapse/api/events/__init__.py b/synapse/api/events/__init__.py index 168b812311..fc3f350570 100644 --- a/synapse/api/events/__init__.py +++ b/synapse/api/events/__init__.py @@ -60,6 +60,7 @@ class SynapseEvent(JsonEncodedObject): "age_ts", "prev_content", "prev_state", + "replaces_state", "redacted_because", "origin_server_ts", ] diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 6e897e915d..164363cdc5 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -147,10 +147,7 @@ class DirectoryHandler(BaseHandler): content={"aliases": aliases}, ) - snapshot = yield self.store.snapshot_room( - room_id=room_id, - user_id=user_id, - ) + snapshot = yield self.store.snapshot_room(event) yield self._on_new_room_event( event, snapshot, extra_users=[user_id], suppress_auth=True diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1464a60937..513ec9a5e3 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -313,9 +313,7 @@ class FederationHandler(BaseHandler): state_key=user_id, ) - snapshot = yield self.store.snapshot_room( - event.room_id, event.user_id, - ) + snapshot = yield self.store.snapshot_room(event) snapshot.fill_out_prev_events(event) yield self.state_handler.annotate_state_groups(event) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c6f6ab14d1..8394013df3 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -81,7 +81,7 @@ class MessageHandler(BaseHandler): user = self.hs.parse_userid(event.user_id) assert user.is_mine, "User must be our own: %s" % (user,) - snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) + snapshot = yield self.store.snapshot_room(event) yield self._on_new_room_event( event, snapshot, suppress_auth=suppress_auth @@ -141,12 +141,7 @@ class MessageHandler(BaseHandler): SynapseError if something went wrong. """ - snapshot = yield self.store.snapshot_room( - event.room_id, - event.user_id, - state_type=event.type, - state_key=event.state_key, - ) + snapshot = yield self.store.snapshot_room(event) yield self._on_new_room_event(event, snapshot) @@ -214,7 +209,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def send_feedback(self, event): - snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) + snapshot = yield self.store.snapshot_room(event) # store message in db yield self._on_new_room_event(event, snapshot) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 4cd0a06093..e47814483a 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -17,7 +17,6 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.constants import Membership -from synapse.api.events.room import RoomMemberEvent from ._base import BaseHandler @@ -196,10 +195,7 @@ class ProfileHandler(BaseHandler): ) for j in joins: - snapshot = yield self.store.snapshot_room( - j.room_id, j.state_key, RoomMemberEvent.TYPE, - j.state_key - ) + snapshot = yield self.store.snapshot_room(j) content = { "membership": j.content["membership"], diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f176ad39bf..55c893eb58 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -122,10 +122,7 @@ class RoomCreationHandler(BaseHandler): @defer.inlineCallbacks def handle_event(event): - snapshot = yield self.store.snapshot_room( - room_id=room_id, - user_id=user_id, - ) + snapshot = yield self.store.snapshot_room(event) logger.debug("Event: %s", event) @@ -364,10 +361,8 @@ class RoomMemberHandler(BaseHandler): """ target_user_id = event.state_key - snapshot = yield self.store.snapshot_room( - event.room_id, event.user_id, - RoomMemberEvent.TYPE, target_user_id - ) + snapshot = yield self.store.snapshot_room(event) + ## TODO(markjh): get prev state from snapshot. prev_state = yield self.store.get_room_member( target_user_id, event.room_id @@ -442,10 +437,7 @@ class RoomMemberHandler(BaseHandler): content=content, ) - snapshot = yield self.store.snapshot_room( - room_id, joinee.to_string(), RoomMemberEvent.TYPE, - joinee.to_string() - ) + snapshot = yield self.store.snapshot_room(new_event) yield self._do_join(new_event, snapshot, room_host=host, do_auth=True) diff --git a/synapse/rest/room.py b/synapse/rest/room.py index ec0ce78fda..997895dab0 100644 --- a/synapse/rest/room.py +++ b/synapse/rest/room.py @@ -138,7 +138,7 @@ class RoomStateEventRestServlet(RestServlet): raise SynapseError( 404, "Event not found.", errcode=Codes.NOT_FOUND ) - defer.returnValue((200, data[0].get_dict()["content"])) + defer.returnValue((200, data.get_dict()["content"])) @defer.inlineCallbacks def on_PUT(self, request, room_id, event_type, state_key): diff --git a/synapse/state.py b/synapse/state.py index 32744e047c..97a8160a33 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -45,40 +45,6 @@ class StateHandler(object): self.server_name = hs.hostname self.hs = hs - @defer.inlineCallbacks - @log_function - def handle_new_event(self, event, snapshot): - """ Given an event this works out if a) we have sufficient power level - to update the state and b) works out what the prev_state should be. - - Returns: - Deferred: Resolved with a boolean indicating if we successfully - updated the state. - - Raised: - AuthError - """ - # This needs to be done in a transaction. - - if not hasattr(event, "state_key"): - return - - # Now I need to fill out the prev state and work out if it has auth - # (w.r.t. to power levels) - - snapshot.fill_out_prev_events(event) - yield self.annotate_state_groups(event) - - if event.old_state_events: - current_state = event.old_state_events.get( - (event.type, event.state_key) - ) - - if current_state: - event.prev_state = current_state.event_id - - defer.returnValue(True) - @defer.inlineCallbacks @log_function def annotate_state_groups(self, event, old_state=None): @@ -111,7 +77,10 @@ class StateHandler(object): event.old_state_events = copy.deepcopy(new_state) if hasattr(event, "state_key"): - new_state[(event.type, event.state_key)] = event + key = (event.type, event.state_key) + if key in new_state: + event.replaces_state = new_state[key].event_id + new_state[key] = event event.state_group = None event.state_events = new_state diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 6b8fed4502..2d62fc2ed0 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -242,8 +242,8 @@ class DataStore(RoomMemberStore, RoomStore, "state_key": event.state_key, } - if hasattr(event, "prev_state"): - vals["prev_state"] = event.prev_state + if hasattr(event, "replaces_state"): + vals["prev_state"] = event.replaces_state self._simple_insert_txn(txn, "state_events", vals) @@ -258,6 +258,40 @@ class DataStore(RoomMemberStore, RoomStore, } ) + for e_id, h in event.prev_state: + self._simple_insert_txn( + txn, + table="event_edges", + values={ + "event_id": event.event_id, + "prev_event_id": e_id, + "room_id": event.room_id, + "is_state": 1, + }, + or_ignore=True, + ) + + if not backfilled: + self._simple_insert_txn( + txn, + table="state_forward_extremities", + values={ + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "state_key": event.state_key, + } + ) + + for prev_state_id, _ in event.prev_state: + self._simple_delete_txn( + txn, + table="state_forward_extremities", + keyvalues={ + "event_id": prev_state_id, + } + ) + for hash_alg, hash_base64 in event.hashes.items(): hash_bytes = decode_base64(hash_base64) self._store_event_content_hash_txn( @@ -357,7 +391,7 @@ class DataStore(RoomMemberStore, RoomStore, ], ) - def snapshot_room(self, room_id, user_id, state_type=None, state_key=None): + def snapshot_room(self, event): """Snapshot the room for an update by a user Args: room_id (synapse.types.RoomId): The room to snapshot. @@ -368,16 +402,29 @@ class DataStore(RoomMemberStore, RoomStore, synapse.storage.Snapshot: A snapshot of the state of the room. """ def _snapshot(txn): - membership_state = self._get_room_member(txn, user_id, room_id) - prev_events = self._get_latest_events_in_room(txn, room_id) + prev_events = self._get_latest_events_in_room( + txn, + event.room_id + ) + + prev_state = None + state_key = None + if hasattr(event, "state_key"): + state_key = event.state_key + prev_state = self._get_latest_state_in_room( + txn, + event.room_id, + type=event.type, + state_key=state_key, + ) return Snapshot( store=self, - room_id=room_id, - user_id=user_id, + room_id=event.room_id, + user_id=event.user_id, prev_events=prev_events, - membership_state=membership_state, - state_type=state_type, + prev_state=prev_state, + state_type=event.type, state_key=state_key, ) @@ -400,30 +447,29 @@ class Snapshot(object): """ def __init__(self, store, room_id, user_id, prev_events, - membership_state, state_type=None, state_key=None, - prev_state_pdu=None): + prev_state, state_type=None, state_key=None): self.store = store self.room_id = room_id self.user_id = user_id self.prev_events = prev_events - self.membership_state = membership_state + self.prev_state = prev_state self.state_type = state_type self.state_key = state_key - self.prev_state_pdu = prev_state_pdu def fill_out_prev_events(self, event): - if hasattr(event, "prev_events"): - return + if not hasattr(event, "prev_events"): + event.prev_events = [ + (event_id, hashes) + for event_id, hashes, _ in self.prev_events + ] - event.prev_events = [ - (event_id, hashes) - for event_id, hashes, _ in self.prev_events - ] + if self.prev_events: + event.depth = max([int(v) for _, _, v in self.prev_events]) + 1 + else: + event.depth = 0 - if self.prev_events: - event.depth = max([int(v) for _, _, v in self.prev_events]) + 1 - else: - event.depth = 0 + if not hasattr(event, "prev_state") and self.prev_state is not None: + event.prev_state = self.prev_state def schema_path(schema): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 7d445b4633..7821fc4726 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -245,7 +245,6 @@ class SQLBaseStore(object): return [r[0] for r in txn.fetchall()] - def _simple_select_onecol(self, table, keyvalues, retcol): """Executes a SELECT query on the named table, which returns a list comprising of the values of the named column from the selected rows. @@ -273,17 +272,30 @@ class SQLBaseStore(object): keyvalues : dict of column names and values to select the rows with retcols : list of strings giving the names of the columns to return """ + return self.runInteraction( + "_simple_select_list", + self._simple_select_list_txn, + table, keyvalues, retcols + ) + + def _simple_select_list_txn(self, txn, table, keyvalues, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + txn : Transaction object + table : string giving the table name + keyvalues : dict of column names and values to select the rows with + retcols : list of strings giving the names of the columns to return + """ sql = "SELECT %s FROM %s WHERE %s" % ( ", ".join(retcols), table, - " AND ".join("%s = ?" % (k) for k in keyvalues) + " AND ".join("%s = ?" % (k, ) for k in keyvalues) ) - def func(txn): - txn.execute(sql, keyvalues.values()) - return self.cursor_to_dict(txn) - - return self.runInteraction("_simple_select_list", func) + txn.execute(sql, keyvalues.values()) + return self.cursor_to_dict(txn) def _simple_update_one(self, table, keyvalues, updatevalues, retcols=None): @@ -417,6 +429,10 @@ class SQLBaseStore(object): d.pop("topological_ordering", None) d.pop("processed", None) d["origin_server_ts"] = d.pop("ts", 0) + replaces_state = d.pop("prev_state", None) + + if replaces_state: + d["replaces_state"] = replaces_state d.update(json.loads(row_dict["unrecognized_keys"])) d["content"] = json.loads(d["content"]) @@ -450,16 +466,32 @@ class SQLBaseStore(object): k: encode_base64(v) for k, v in signatures.items() } - ev.prev_events = self._get_prev_events(txn, ev.event_id) + prevs = self._get_prev_events_and_state(txn, ev.event_id) - if hasattr(ev, "prev_state"): - # Load previous state_content. - # TODO: Should we be pulling this out above? - cursor = txn.execute(select_event_sql, (ev.prev_state,)) - prevs = self.cursor_to_dict(cursor) - if prevs: - prev = self._parse_event_from_row(prevs[0]) - ev.prev_content = prev.content + ev.prev_events = [ + (e_id, h) + for e_id, h, is_state in prevs + if is_state == 0 + ] + + if hasattr(ev, "state_key"): + ev.prev_state = [ + (e_id, h) + for e_id, h, is_state in prevs + if is_state == 1 + ] + + if hasattr(ev, "replaces_state"): + # Load previous state_content. + # FIXME (erikj): Handle multiple prev_states. + cursor = txn.execute( + select_event_sql, + (ev.replaces_state,) + ) + prevs = self.cursor_to_dict(cursor) + if prevs: + prev = self._parse_event_from_row(prevs[0]) + ev.prev_content = prev.content if not hasattr(ev, "redacted"): logger.debug("Doesn't have redacted key: %s", ev) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index f427aba879..180a764134 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -69,19 +69,21 @@ class EventFederationStore(SQLBaseStore): return results - def _get_prev_events(self, txn, event_id): - prev_ids = self._simple_select_onecol_txn( + def _get_latest_state_in_room(self, txn, room_id, type, state_key): + event_ids = self._simple_select_onecol_txn( txn, - table="event_edges", + table="state_forward_extremities", keyvalues={ - "event_id": event_id, + "room_id": room_id, + "type": type, + "state_key": state_key, }, - retcol="prev_event_id", + retcol="event_id", ) results = [] - for prev_event_id in prev_ids: - hashes = self._get_event_reference_hashes_txn(txn, prev_event_id) + for event_id in event_ids: + hashes = self._get_event_reference_hashes_txn(txn, event_id) prev_hashes = { k: encode_base64(v) for k, v in hashes.items() if k == "sha256" @@ -90,6 +92,53 @@ class EventFederationStore(SQLBaseStore): return results + def _get_prev_events(self, txn, event_id): + results = self._get_prev_events_and_state( + txn, + event_id, + is_state=0, + ) + + return [(e_id, h, ) for e_id, h, _ in results] + + def _get_prev_state(self, txn, event_id): + results = self._get_prev_events_and_state( + txn, + event_id, + is_state=1, + ) + + return [(e_id, h, ) for e_id, h, _ in results] + + def _get_prev_events_and_state(self, txn, event_id, is_state=None): + keyvalues = { + "event_id": event_id, + } + + if is_state is not None: + keyvalues["is_state"] = is_state + + res = self._simple_select_list_txn( + txn, + table="event_edges", + keyvalues=keyvalues, + retcols=["prev_event_id", "is_state"], + ) + + results = [] + for d in res: + hashes = self._get_event_reference_hashes_txn( + txn, + d["prev_event_id"] + ) + prev_hashes = { + k: encode_base64(v) for k, v in hashes.items() + if k == "sha256" + } + results.append((d["prev_event_id"], prev_hashes, d["is_state"])) + + return results + def get_min_depth(self, room_id): return self.runInteraction( "get_min_depth", @@ -135,6 +184,7 @@ class EventFederationStore(SQLBaseStore): "event_id": event_id, "prev_event_id": e_id, "room_id": room_id, + "is_state": 0, }, or_ignore=True, ) diff --git a/synapse/storage/schema/event_edges.sql b/synapse/storage/schema/event_edges.sql index e5f768c705..51695826a8 100644 --- a/synapse/storage/schema/event_edges.sql +++ b/synapse/storage/schema/event_edges.sql @@ -1,7 +1,7 @@ CREATE TABLE IF NOT EXISTS event_forward_extremities( - event_id TEXT, - room_id TEXT, + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE ); @@ -10,8 +10,8 @@ CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id); CREATE TABLE IF NOT EXISTS event_backward_extremities( - event_id TEXT, - room_id TEXT, + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE ); @@ -20,10 +20,11 @@ CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id CREATE TABLE IF NOT EXISTS event_edges( - event_id TEXT, - prev_event_id TEXT, - room_id TEXT, - CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id) + event_id TEXT NOT NULL, + prev_event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + is_state INTEGER NOT NULL, + CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id, is_state) ); CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id); @@ -31,8 +32,8 @@ CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id); CREATE TABLE IF NOT EXISTS room_depth( - room_id TEXT, - min_depth INTEGER, + room_id TEXT NOT NULL, + min_depth INTEGER NOT NULL, CONSTRAINT uniqueness UNIQUE (room_id) ); @@ -40,10 +41,25 @@ CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id); create TABLE IF NOT EXISTS event_destinations( - event_id TEXT, - destination TEXT, + event_id TEXT NOT NULL, + destination TEXT NOT NULL, delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE ); CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id); + + +CREATE TABLE IF NOT EXISTS state_forward_extremities( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE +); + +CREATE INDEX IF NOT EXISTS st_extrem_keys ON state_forward_extremities( + room_id, type, state_key +); +CREATE INDEX IF NOT EXISTS st_extrem_id ON state_forward_extremities(event_id); + diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py index c91eb897a8..e79b68f661 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py @@ -80,7 +80,7 @@ class JsonEncodedObject(object): def get_full_dict(self): d = { - k: v for (k, v) in self.__dict__.items() + k: _encode(v) for (k, v) in self.__dict__.items() if k in self.valid_keys or k in self.internal_keys } d.update(self.unrecognized_keys)