From 2a321bac35b872d47d8ae8da4cba31d757e96a26 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 6 Nov 2024 22:21:06 +0000 Subject: [PATCH] Issue one time keys in upload order (#17903) Currently, one-time-keys are issued in a somewhat random order. (In practice, they are issued according to the lexicographical order of their key IDs.) That can lead to a situation where a client gives up hope of a given OTK ever being used, whilst it is still on the server. Related: https://github.com/element-hq/element-meta/issues/2356 --- changelog.d/17903.bugfix | 1 + synapse/handlers/e2e_keys.py | 2 +- .../storage/databases/main/end_to_end_keys.py | 25 +++++- .../delta/88/03_add_otk_ts_added_index.sql | 18 +++++ tests/handlers/test_e2e_keys.py | 78 +++++++++++++++++-- 5 files changed, 116 insertions(+), 8 deletions(-) create mode 100644 changelog.d/17903.bugfix create mode 100644 synapse/storage/schema/main/delta/88/03_add_otk_ts_added_index.sql diff --git a/changelog.d/17903.bugfix b/changelog.d/17903.bugfix new file mode 100644 index 0000000000..a4d02fc983 --- /dev/null +++ b/changelog.d/17903.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug in Synapse which could cause one-time keys to be issued in the incorrect order, causing message decryption failures. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index f78e66ad0a..315461fefb 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -615,7 +615,7 @@ class E2eKeysHandler: 3. Attempt to fetch fallback keys from the database. Args: - local_query: An iterable of tuples of (user ID, device ID, algorithm). + local_query: An iterable of tuples of (user ID, device ID, algorithm, number of keys). always_include_fallback_keys: True to always include fallback keys. Returns: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 575aaf498b..1fbc49e7c5 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -99,6 +99,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore): unique=True, ) + self.db_pool.updates.register_background_index_update( + update_name="add_otk_ts_added_index", + index_name="e2e_one_time_keys_json_user_id_device_id_algorithm_ts_added_idx", + table="e2e_one_time_keys_json", + columns=("user_id", "device_id", "algorithm", "ts_added_ms"), + ) + class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore): def __init__( @@ -1122,7 +1129,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """Take a list of one time keys out of the database. Args: - query_list: An iterable of tuples of (user ID, device ID, algorithm). + query_list: An iterable of tuples of (user ID, device ID, algorithm, number of keys). Returns: A tuple (results, missing) of: @@ -1310,9 +1317,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker OTK was found. """ + # Return the oldest keys from this device (based on `ts_added_ms`). + # Doing so means that keys are issued in the same order they were uploaded, + # which reduces the chances of a client expiring its copy of a (private) + # key while the public key is still on the server, waiting to be issued. sql = """ SELECT key_id, key_json FROM e2e_one_time_keys_json WHERE user_id = ? AND device_id = ? AND algorithm = ? + ORDER BY ts_added_ms LIMIT ? """ @@ -1354,13 +1366,22 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker A list of tuples (user_id, device_id, algorithm, key_id, key_json) for each OTK claimed. """ + # Find, delete, and return the oldest keys from each device (based on + # `ts_added_ms`). + # + # Doing so means that keys are issued in the same order they were uploaded, + # which reduces the chances of a client expiring its copy of a (private) + # key while the public key is still on the server, waiting to be issued. sql = """ WITH claims(user_id, device_id, algorithm, claim_count) AS ( VALUES ? ), ranked_keys AS ( SELECT user_id, device_id, algorithm, key_id, claim_count, - ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r + ROW_NUMBER() OVER ( + PARTITION BY (user_id, device_id, algorithm) + ORDER BY ts_added_ms + ) AS r FROM e2e_one_time_keys_json JOIN claims USING (user_id, device_id, algorithm) ) diff --git a/synapse/storage/schema/main/delta/88/03_add_otk_ts_added_index.sql b/synapse/storage/schema/main/delta/88/03_add_otk_ts_added_index.sql new file mode 100644 index 0000000000..7712ea68ad --- /dev/null +++ b/synapse/storage/schema/main/delta/88/03_add_otk_ts_added_index.sql @@ -0,0 +1,18 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + + +-- Add an index on (user_id, device_id, algorithm, ts_added_ms) on e2e_one_time_keys_json, so that OTKs can +-- efficiently be issued in the same order they were uploaded. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (8803, 'add_otk_ts_added_index', '{}'); diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 8a3dfdcf75..bca314db83 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -151,18 +151,30 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def test_claim_one_time_key(self) -> None: local_user = "@boris:" + self.hs.hostname device_id = "xyz" - keys = {"alg1:k1": "key1"} - res = self.get_success( self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + local_user, device_id, {"one_time_keys": {"alg1:k1": "key1"}} ) ) self.assertDictEqual( res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}} ) - res2 = self.get_success( + # Keys should be returned in the order they were uploaded. To test, advance time + # a little, then upload a second key with an earlier key ID; it should get + # returned second. + self.reactor.advance(1) + res = self.get_success( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg1:k0": "key0"}} + ) + ) + self.assertDictEqual( + res, {"one_time_key_counts": {"alg1": 2, "signed_curve25519": 0}} + ) + + # now claim both keys back. They should be in the same order + res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, self.requester, @@ -171,12 +183,27 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) ) self.assertEqual( - res2, + res, { "failures": {}, "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}}, }, ) + res = self.get_success( + self.handler.claim_one_time_keys( + {local_user: {device_id: {"alg1": 1}}}, + self.requester, + timeout=None, + always_include_fallback_keys=False, + ) + ) + self.assertEqual( + res, + { + "failures": {}, + "one_time_keys": {local_user: {device_id: {"alg1:k0": "key0"}}}, + }, + ) def test_claim_one_time_key_bulk(self) -> None: """Like test_claim_one_time_key but claims multiple keys in one handler call.""" @@ -336,6 +363,47 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}" ) + def test_claim_one_time_key_bulk_ordering(self) -> None: + """Keys returned by the bulk claim call should be returned in the correct order""" + + # Alice has lots of keys, uploaded in a specific order + alice = f"@alice:{self.hs.hostname}" + alice_dev = "alice_dev_1" + + self.get_success( + self.handler.upload_keys_for_user( + alice, + alice_dev, + {"one_time_keys": {"alg1:k20": 20, "alg1:k21": 21, "alg1:k22": 22}}, + ) + ) + # Advance time by 1s, to ensure that there is a difference in upload time. + self.reactor.advance(1) + self.get_success( + self.handler.upload_keys_for_user( + alice, + alice_dev, + {"one_time_keys": {"alg1:k10": 10, "alg1:k11": 11, "alg1:k12": 12}}, + ) + ) + + # Now claim some, and check we get the right ones. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {alice: {alice_dev: {"alg1": 2}}}, + self.requester, + timeout=None, + always_include_fallback_keys=False, + ) + ) + # We should get the first-uploaded keys, even though they have later key ids. + # We should get a random set of two of k20, k21, k22. + self.assertEqual(claim_res["failures"], {}) + claimed_keys = claim_res["one_time_keys"]["@alice:test"]["alice_dev_1"] + self.assertEqual(len(claimed_keys), 2) + for key_id in claimed_keys.keys(): + self.assertIn(key_id, ["alg1:k20", "alg1:k21", "alg1:k22"]) + def test_fallback_key(self) -> None: local_user = "@boris:" + self.hs.hostname device_id = "xyz"