Issue one time keys in upload order (#17903)

Currently, one-time-keys are issued in a somewhat random order. (In
practice, they are issued according to the lexicographical order of
their key IDs.) That can lead to a situation where a client gives up
hope of a given OTK ever being used, whilst it is still on the server.

Related: https://github.com/element-hq/element-meta/issues/2356
This commit is contained in:
Richard van der Hoff 2024-11-06 22:21:06 +00:00 committed by GitHub
parent eda735e4bb
commit 2a321bac35
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 116 additions and 8 deletions

1
changelog.d/17903.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a long-standing bug in Synapse which could cause one-time keys to be issued in the incorrect order, causing message decryption failures.

View file

@ -615,7 +615,7 @@ class E2eKeysHandler:
3. Attempt to fetch fallback keys from the database. 3. Attempt to fetch fallback keys from the database.
Args: Args:
local_query: An iterable of tuples of (user ID, device ID, algorithm). local_query: An iterable of tuples of (user ID, device ID, algorithm, number of keys).
always_include_fallback_keys: True to always include fallback keys. always_include_fallback_keys: True to always include fallback keys.
Returns: Returns:

View file

@ -99,6 +99,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore):
unique=True, unique=True,
) )
self.db_pool.updates.register_background_index_update(
update_name="add_otk_ts_added_index",
index_name="e2e_one_time_keys_json_user_id_device_id_algorithm_ts_added_idx",
table="e2e_one_time_keys_json",
columns=("user_id", "device_id", "algorithm", "ts_added_ms"),
)
class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore): class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
def __init__( def __init__(
@ -1122,7 +1129,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
"""Take a list of one time keys out of the database. """Take a list of one time keys out of the database.
Args: Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm). query_list: An iterable of tuples of (user ID, device ID, algorithm, number of keys).
Returns: Returns:
A tuple (results, missing) of: A tuple (results, missing) of:
@ -1310,9 +1317,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
OTK was found. OTK was found.
""" """
# Return the oldest keys from this device (based on `ts_added_ms`).
# Doing so means that keys are issued in the same order they were uploaded,
# which reduces the chances of a client expiring its copy of a (private)
# key while the public key is still on the server, waiting to be issued.
sql = """ sql = """
SELECT key_id, key_json FROM e2e_one_time_keys_json SELECT key_id, key_json FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ? WHERE user_id = ? AND device_id = ? AND algorithm = ?
ORDER BY ts_added_ms
LIMIT ? LIMIT ?
""" """
@ -1354,13 +1366,22 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
A list of tuples (user_id, device_id, algorithm, key_id, key_json) A list of tuples (user_id, device_id, algorithm, key_id, key_json)
for each OTK claimed. for each OTK claimed.
""" """
# Find, delete, and return the oldest keys from each device (based on
# `ts_added_ms`).
#
# Doing so means that keys are issued in the same order they were uploaded,
# which reduces the chances of a client expiring its copy of a (private)
# key while the public key is still on the server, waiting to be issued.
sql = """ sql = """
WITH claims(user_id, device_id, algorithm, claim_count) AS ( WITH claims(user_id, device_id, algorithm, claim_count) AS (
VALUES ? VALUES ?
), ranked_keys AS ( ), ranked_keys AS (
SELECT SELECT
user_id, device_id, algorithm, key_id, claim_count, user_id, device_id, algorithm, key_id, claim_count,
ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r ROW_NUMBER() OVER (
PARTITION BY (user_id, device_id, algorithm)
ORDER BY ts_added_ms
) AS r
FROM e2e_one_time_keys_json FROM e2e_one_time_keys_json
JOIN claims USING (user_id, device_id, algorithm) JOIN claims USING (user_id, device_id, algorithm)
) )

View file

@ -0,0 +1,18 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Add an index on (user_id, device_id, algorithm, ts_added_ms) on e2e_one_time_keys_json, so that OTKs can
-- efficiently be issued in the same order they were uploaded.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(8803, 'add_otk_ts_added_index', '{}');

View file

@ -151,18 +151,30 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
def test_claim_one_time_key(self) -> None: def test_claim_one_time_key(self) -> None:
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
device_id = "xyz" device_id = "xyz"
keys = {"alg1:k1": "key1"}
res = self.get_success( res = self.get_success(
self.handler.upload_keys_for_user( self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys} local_user, device_id, {"one_time_keys": {"alg1:k1": "key1"}}
) )
) )
self.assertDictEqual( self.assertDictEqual(
res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}} res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
) )
res2 = self.get_success( # Keys should be returned in the order they were uploaded. To test, advance time
# a little, then upload a second key with an earlier key ID; it should get
# returned second.
self.reactor.advance(1)
res = self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg1:k0": "key0"}}
)
)
self.assertDictEqual(
res, {"one_time_key_counts": {"alg1": 2, "signed_curve25519": 0}}
)
# now claim both keys back. They should be in the same order
res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}}, {local_user: {device_id: {"alg1": 1}}},
self.requester, self.requester,
@ -171,12 +183,27 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
) )
) )
self.assertEqual( self.assertEqual(
res2, res,
{ {
"failures": {}, "failures": {},
"one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}}, "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
}, },
) )
res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
res,
{
"failures": {},
"one_time_keys": {local_user: {device_id: {"alg1:k0": "key0"}}},
},
)
def test_claim_one_time_key_bulk(self) -> None: def test_claim_one_time_key_bulk(self) -> None:
"""Like test_claim_one_time_key but claims multiple keys in one handler call.""" """Like test_claim_one_time_key but claims multiple keys in one handler call."""
@ -336,6 +363,47 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}" counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
) )
def test_claim_one_time_key_bulk_ordering(self) -> None:
"""Keys returned by the bulk claim call should be returned in the correct order"""
# Alice has lots of keys, uploaded in a specific order
alice = f"@alice:{self.hs.hostname}"
alice_dev = "alice_dev_1"
self.get_success(
self.handler.upload_keys_for_user(
alice,
alice_dev,
{"one_time_keys": {"alg1:k20": 20, "alg1:k21": 21, "alg1:k22": 22}},
)
)
# Advance time by 1s, to ensure that there is a difference in upload time.
self.reactor.advance(1)
self.get_success(
self.handler.upload_keys_for_user(
alice,
alice_dev,
{"one_time_keys": {"alg1:k10": 10, "alg1:k11": 11, "alg1:k12": 12}},
)
)
# Now claim some, and check we get the right ones.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{alice: {alice_dev: {"alg1": 2}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
)
# We should get the first-uploaded keys, even though they have later key ids.
# We should get a random set of two of k20, k21, k22.
self.assertEqual(claim_res["failures"], {})
claimed_keys = claim_res["one_time_keys"]["@alice:test"]["alice_dev_1"]
self.assertEqual(len(claimed_keys), 2)
for key_id in claimed_keys.keys():
self.assertIn(key_id, ["alg1:k20", "alg1:k21", "alg1:k22"])
def test_fallback_key(self) -> None: def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
device_id = "xyz" device_id = "xyz"