From 84abdfdb3052f970814371e1cea3ef0a258be2d2 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 14 Aug 2019 17:43:22 -0700 Subject: [PATCH] add hash and count to key backup endpoints --- synapse/handlers/e2e_room_keys.py | 76 ++++---- synapse/rest/client/v2_alpha/room_keys.py | 4 +- synapse/storage/e2e_room_keys.py | 171 +++++++++++++----- .../storage/schema/delta/56/room_key_hash.sql | 17 ++ tests/handlers/test_e2e_room_keys.py | 29 +++ 5 files changed, 213 insertions(+), 84 deletions(-) create mode 100644 synapse/storage/schema/delta/56/room_key_hash.sql diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 41b871fc59..1666222de8 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017, 2018 New Vector Ltd +# Copyright 2019 Matrix.org Foundation CIC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -166,43 +167,46 @@ class E2eRoomKeysHandler(object): else: raise - # go through the room_keys. - # XXX: this should/could be done concurrently, given we're in a lock. + # Fetch any existing room keys for the sessions that have been + # submitted. Then compare them with the submitted keys. If the + # key is new, insert it; if the key should be updated, then update + # it; otherwise, drop it. + existing_keys = yield self.store.get_e2e_room_keys_multi( + user_id, version, room_keys["rooms"] + ) + to_insert = [] # batch the inserts together + changed = False # if anything has changed, we need to update the hash for room_id, room in iteritems(room_keys["rooms"]): - for session_id, session in iteritems(room["sessions"]): - yield self._upload_room_key( - user_id, version, room_id, session_id, session - ) + for session_id, room_key in iteritems(room["sessions"]): + current_room_key = existing_keys.get(room_id, {}).get(session_id) + if current_room_key: + if self._should_replace_room_key(current_room_key, room_key): + # updates are done one at a time in the DB, so send + # updates right away rather than batching them up, + # like we do with the inserts + yield self.store.update_e2e_room_key( + user_id, version, room_id, session_id, room_key + ) + changed = True + else: + to_insert.append((room_id, session_id, room_key)) + changed = True - @defer.inlineCallbacks - def _upload_room_key(self, user_id, version, room_id, session_id, room_key): - """Upload a given room_key for a given room and session into a given - version of the backup. Merges the key with any which might already exist. + if len(to_insert): + yield self.store.add_e2e_room_keys(user_id, version, to_insert) - Args: - user_id(str): the user whose backup we're setting - version(str): the version ID of the backup we're updating - room_id(str): the ID of the room whose keys we're setting - session_id(str): the session whose room_key we're setting - room_key(dict): the room_key being set - """ + version_hash = int(version_info["hash"]) + if changed: + version_hash = version_hash + 1 + yield self.store.update_e2e_room_keys_version( + user_id, version, {"hash": str(version_hash)} + ) - # get the room_key for this particular row - current_room_key = None - try: - current_room_key = yield self.store.get_e2e_room_key( - user_id, version, room_id, session_id - ) - except StoreError as e: - if e.code == 404: - pass - else: - raise - - if self._should_replace_room_key(current_room_key, room_key): - yield self.store.set_e2e_room_key( - user_id, version, room_id, session_id, room_key - ) + count = yield self.store.count_e2e_room_keys(user_id, version) + return { + "hash": version_hash, + "count": count + } @staticmethod def _should_replace_room_key(current_room_key, room_key): @@ -292,6 +296,8 @@ class E2eRoomKeysHandler(object): raise NotFoundError("Unknown backup version") else: raise + + res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"]) return res @defer.inlineCallbacks @@ -346,6 +352,10 @@ class E2eRoomKeysHandler(object): if old_info["algorithm"] != version_info["algorithm"]: raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM) + # don't allow the user to set the hash + if "hash" in version_info: + del version_info["hash"] + yield self.store.update_e2e_room_keys_version( user_id, version, version_info ) diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 10dec96208..3b7252fb7e 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -134,8 +134,8 @@ class RoomKeysServlet(RestServlet): if room_id: body = {"rooms": {room_id: body}} - yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) - return (200, {}) + ret = yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) + return (200, ret) @defer.inlineCallbacks def on_GET(self, request, room_id, session_id): diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 99128f2df7..47c773580b 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 New Vector Ltd +# Copyright 2019 Matrix.org Foundation CIC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,49 +25,8 @@ from ._base import SQLBaseStore class EndToEndRoomKeyStore(SQLBaseStore): @defer.inlineCallbacks - def get_e2e_room_key(self, user_id, version, room_id, session_id): - """Get the encrypted E2E room key for a given session from a given - backup version of room_keys. We only store the 'best' room key for a given - session at a given time, as determined by the handler. - - Args: - user_id(str): the user whose backup we're querying - version(str): the version ID of the backup for the set of keys we're querying - room_id(str): the ID of the room whose keys we're querying. - This is a bit redundant as it's implied by the session_id, but - we include for consistency with the rest of the API. - session_id(str): the session whose room_key we're querying. - - Returns: - A deferred dict giving the session_data and message metadata for - this room key. - """ - - row = yield self._simple_select_one( - table="e2e_room_keys", - keyvalues={ - "user_id": user_id, - "version": version, - "room_id": room_id, - "session_id": session_id, - }, - retcols=( - "first_message_index", - "forwarded_count", - "is_verified", - "session_data", - ), - desc="get_e2e_room_key", - ) - - row["session_data"] = json.loads(row["session_data"]) - - return row - - @defer.inlineCallbacks - def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): - """Replaces or inserts the encrypted E2E room key for a given session in - a given backup + def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + """Replaces the encrypted E2E room key for a given session in a given backup Args: user_id(str): the user whose backup we're setting @@ -78,21 +38,46 @@ class EndToEndRoomKeyStore(SQLBaseStore): StoreError """ - yield self._simple_upsert( + yield self._simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, + "version": version, "room_id": room_id, "session_id": session_id, }, - values={ - "version": version, + updatevalues={ "first_message_index": room_key["first_message_index"], "forwarded_count": room_key["forwarded_count"], "is_verified": room_key["is_verified"], "session_data": json.dumps(room_key["session_data"]), }, - lock=False, + desc="update_e2e_room_key" + ) + + @defer.inlineCallbacks + def add_e2e_room_keys(self, user_id, version, room_keys): + """Bulk add room keys to a given backup. + + Args: + user_id(str): the user whose backup we're adding to + version(str): the version ID of the backup for the set of keys we're adding to + room_keys(iterable[dict]): the keys to add + """ + + yield self._simple_insert_many( + table="e2e_room_keys", + values=[{ + "user_id": user_id, + "version": version, + "room_id": room_id, + "session_id": session_id, + "first_message_index": room_key["first_message_index"], + "forwarded_count": room_key["forwarded_count"], + "is_verified": room_key["is_verified"], + "session_data": json.dumps(room_key["session_data"]), + } for (room_id, session_id, room_key) in room_keys], + desc="add_e2e_room_keys" ) @defer.inlineCallbacks @@ -153,6 +138,86 @@ class EndToEndRoomKeyStore(SQLBaseStore): return sessions + def get_e2e_room_keys_multi(self, user_id, version, room_keys): + """Get multiple room keys at a time. The difference between this function and + get_e2e_room_keys is that this function can be used to retrieve + multiple specific keys at a time, whereas get_e2e_room_keys is used for + getting all the keys in a backup version, all the keys for a room, or a + specific key. + + Args: + user_id(str): the user whose backup we're querying + version(str): the version ID of the backup we're querying about + room_keys(dict[dict[iterable[str]]]): a map of room IDs to dict which + has a "session" key that is an iterable of session IDs that we + want to query + + Returns: + dict[dict[dict]]: a map of room IDs to session IDs to room key + """ + + return self.runInteraction( + "get_e2e_room_keys_multi", self._get_e2e_room_keys_multi_txn, + user_id, version, room_keys + ) + + @staticmethod + def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys): + if not len(room_keys): + return {} + + where_clauses = [] + params = [user_id, version] + for room_id, room in room_keys.items(): + sessions = list(room["sessions"]) + if not len(sessions): + continue + params.append(room_id) + params.extend(sessions) + where_clauses.append( + "(room_id = ? AND session_id IN (%s))" + % (",".join(["?" for _ in sessions]),) + ) + + sql = """ + SELECT room_id, session_id, first_message_index, forwarded_count, + is_verified, session_data + FROM e2e_room_keys + WHERE user_id = ? AND version = ? AND (%s) + """ % (" OR ".join(where_clauses)) + + txn.execute(sql, params) + + ret = {} + + for row in txn: + room_id = row[0] + session_id = row[1] + ret.setdefault(room_id, {}) + ret[room_id][session_id] = { + "first_message_index": row[2], + "forwarded_count": row[3], + "is_verified": row[4], + "session_data": json.loads(row[5]), + } + + return ret + + def count_e2e_room_keys(self, user_id, version): + """Get the number of keys in a backup version. + + Args: + user_id(str): the user whose backup we're querying + version(str): the version ID of the backup we're querying about + """ + + return self._simple_select_one_onecol( + table="e2e_room_keys", + keyvalues={"user_id": user_id, "version": version}, + retcol="COUNT(*)", + desc="count_e2e_room_keys", + ) + @defer.inlineCallbacks def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk delete the E2E room keys for a given backup, optionally filtered to a given @@ -226,10 +291,12 @@ class EndToEndRoomKeyStore(SQLBaseStore): txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, - retcols=("version", "algorithm", "auth_data"), + retcols=("version", "algorithm", "auth_data", "hash"), ) result["auth_data"] = json.loads(result["auth_data"]) result["version"] = str(result["version"]) + if not result["hash"]: + result["hash"] = 0 return result return self.runInteraction( @@ -284,11 +351,17 @@ class EndToEndRoomKeyStore(SQLBaseStore): version(str): the version ID of the backup version we're updating info(dict): the new backup version info to store """ + updatevalues = {} + + if "auth_data" in info: + updatevalues["auth_data"] = json.dumps(info["auth_data"]) + if "hash" in info: + updatevalues["hash"] = info["hash"] return self._simple_update( table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": version}, - updatevalues={"auth_data": json.dumps(info["auth_data"])}, + updatevalues=updatevalues, desc="update_e2e_room_keys_version", ) diff --git a/synapse/storage/schema/delta/56/room_key_hash.sql b/synapse/storage/schema/delta/56/room_key_hash.sql new file mode 100644 index 0000000000..fa61cd3ff0 --- /dev/null +++ b/synapse/storage/schema/delta/56/room_key_hash.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 Matrix.org Foundation CIC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- store the current hash of backup version +ALTER TABLE e2e_room_keys_versions ADD COLUMN hash TEXT NULLABLE; diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index c4503c1611..0e3bf70749 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2017 New Vector Ltd +# Copyright 2019 Matrix.org Foundation CIC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -94,23 +95,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + version_hash = res["hash"] + del res["hash"] self.assertDictEqual( res, { "version": "1", "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", + "count": 0, }, ) # check we can retrieve it as a specific version res = yield self.handler.get_version_info(self.local_user, "1") + self.assertEqual(res["hash"], version_hash) + del res["hash"] self.assertDictEqual( res, { "version": "1", "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", + "count": 0, }, ) @@ -126,12 +133,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["hash"] self.assertDictEqual( res, { "version": "2", "algorithm": "m.megolm_backup.v1", "auth_data": "second_version_auth_data", + "count": 0, }, ) @@ -158,12 +167,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["hash"] self.assertDictEqual( res, { "algorithm": "m.megolm_backup.v1", "auth_data": "revised_first_version_auth_data", "version": version, + "count": 0, }, ) @@ -394,6 +405,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): yield self.handler.upload_room_keys(self.local_user, version, room_keys) + # get the hash to compare to future versions + res = yield self.handler.get_version_info(self.local_user) + backup_hash = res["hash"] + self.assertEqual(res["count"], 1) + new_room_keys = copy.deepcopy(room_keys) new_room_key = new_room_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"] @@ -408,6 +424,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): "SSBBTSBBIEZJU0gK", ) + # the hash should be the same since the session did not change + res = yield self.handler.get_version_info(self.local_user) + self.assertEqual(res["hash"], backup_hash) + # test that marking the session as verified however /does/ replace it new_room_key["is_verified"] = True yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) @@ -417,6 +437,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) + # the hash should NOT be equal now, since the key changed + res = yield self.handler.get_version_info(self.local_user) + self.assertNotEqual(res["hash"], backup_hash) + backup_hash = res["hash"] + # test that a session with a higher forwarded_count doesn't replace one # with a lower forwarding count new_room_key["forwarded_count"] = 2 @@ -428,6 +453,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) + # the hash should be the same since the session did not change + res = yield self.handler.get_version_info(self.local_user) + self.assertEqual(res["hash"], backup_hash) + # TODO: check edge cases as well as the common variations here @defer.inlineCallbacks