mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-21 09:05:42 +03:00
Issue one time keys in upload order (#17903)
Currently, one-time-keys are issued in a somewhat random order. (In practice, they are issued according to the lexicographical order of their key IDs.) That can lead to a situation where a client gives up hope of a given OTK ever being used, whilst it is still on the server. Related: https://github.com/element-hq/element-meta/issues/2356
This commit is contained in:
parent
eda735e4bb
commit
2a321bac35
5 changed files with 116 additions and 8 deletions
1
changelog.d/17903.bugfix
Normal file
1
changelog.d/17903.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix a long-standing bug in Synapse which could cause one-time keys to be issued in the incorrect order, causing message decryption failures.
|
|
@ -615,7 +615,7 @@ class E2eKeysHandler:
|
||||||
3. Attempt to fetch fallback keys from the database.
|
3. Attempt to fetch fallback keys from the database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
local_query: An iterable of tuples of (user ID, device ID, algorithm).
|
local_query: An iterable of tuples of (user ID, device ID, algorithm, number of keys).
|
||||||
always_include_fallback_keys: True to always include fallback keys.
|
always_include_fallback_keys: True to always include fallback keys.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
|
@ -99,6 +99,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore):
|
||||||
unique=True,
|
unique=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.db_pool.updates.register_background_index_update(
|
||||||
|
update_name="add_otk_ts_added_index",
|
||||||
|
index_name="e2e_one_time_keys_json_user_id_device_id_algorithm_ts_added_idx",
|
||||||
|
table="e2e_one_time_keys_json",
|
||||||
|
columns=("user_id", "device_id", "algorithm", "ts_added_ms"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
|
class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -1122,7 +1129,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||||
"""Take a list of one time keys out of the database.
|
"""Take a list of one time keys out of the database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_list: An iterable of tuples of (user ID, device ID, algorithm).
|
query_list: An iterable of tuples of (user ID, device ID, algorithm, number of keys).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple (results, missing) of:
|
A tuple (results, missing) of:
|
||||||
|
@ -1310,9 +1317,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||||
OTK was found.
|
OTK was found.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Return the oldest keys from this device (based on `ts_added_ms`).
|
||||||
|
# Doing so means that keys are issued in the same order they were uploaded,
|
||||||
|
# which reduces the chances of a client expiring its copy of a (private)
|
||||||
|
# key while the public key is still on the server, waiting to be issued.
|
||||||
sql = """
|
sql = """
|
||||||
SELECT key_id, key_json FROM e2e_one_time_keys_json
|
SELECT key_id, key_json FROM e2e_one_time_keys_json
|
||||||
WHERE user_id = ? AND device_id = ? AND algorithm = ?
|
WHERE user_id = ? AND device_id = ? AND algorithm = ?
|
||||||
|
ORDER BY ts_added_ms
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -1354,13 +1366,22 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||||
A list of tuples (user_id, device_id, algorithm, key_id, key_json)
|
A list of tuples (user_id, device_id, algorithm, key_id, key_json)
|
||||||
for each OTK claimed.
|
for each OTK claimed.
|
||||||
"""
|
"""
|
||||||
|
# Find, delete, and return the oldest keys from each device (based on
|
||||||
|
# `ts_added_ms`).
|
||||||
|
#
|
||||||
|
# Doing so means that keys are issued in the same order they were uploaded,
|
||||||
|
# which reduces the chances of a client expiring its copy of a (private)
|
||||||
|
# key while the public key is still on the server, waiting to be issued.
|
||||||
sql = """
|
sql = """
|
||||||
WITH claims(user_id, device_id, algorithm, claim_count) AS (
|
WITH claims(user_id, device_id, algorithm, claim_count) AS (
|
||||||
VALUES ?
|
VALUES ?
|
||||||
), ranked_keys AS (
|
), ranked_keys AS (
|
||||||
SELECT
|
SELECT
|
||||||
user_id, device_id, algorithm, key_id, claim_count,
|
user_id, device_id, algorithm, key_id, claim_count,
|
||||||
ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
|
ROW_NUMBER() OVER (
|
||||||
|
PARTITION BY (user_id, device_id, algorithm)
|
||||||
|
ORDER BY ts_added_ms
|
||||||
|
) AS r
|
||||||
FROM e2e_one_time_keys_json
|
FROM e2e_one_time_keys_json
|
||||||
JOIN claims USING (user_id, device_id, algorithm)
|
JOIN claims USING (user_id, device_id, algorithm)
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
--
|
||||||
|
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
--
|
||||||
|
-- Copyright (C) 2024 New Vector, Ltd
|
||||||
|
--
|
||||||
|
-- This program is free software: you can redistribute it and/or modify
|
||||||
|
-- it under the terms of the GNU Affero General Public License as
|
||||||
|
-- published by the Free Software Foundation, either version 3 of the
|
||||||
|
-- License, or (at your option) any later version.
|
||||||
|
--
|
||||||
|
-- See the GNU Affero General Public License for more details:
|
||||||
|
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
|
||||||
|
|
||||||
|
-- Add an index on (user_id, device_id, algorithm, ts_added_ms) on e2e_one_time_keys_json, so that OTKs can
|
||||||
|
-- efficiently be issued in the same order they were uploaded.
|
||||||
|
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||||
|
(8803, 'add_otk_ts_added_index', '{}');
|
|
@ -151,18 +151,30 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
def test_claim_one_time_key(self) -> None:
|
def test_claim_one_time_key(self) -> None:
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
device_id = "xyz"
|
device_id = "xyz"
|
||||||
keys = {"alg1:k1": "key1"}
|
|
||||||
|
|
||||||
res = self.get_success(
|
res = self.get_success(
|
||||||
self.handler.upload_keys_for_user(
|
self.handler.upload_keys_for_user(
|
||||||
local_user, device_id, {"one_time_keys": keys}
|
local_user, device_id, {"one_time_keys": {"alg1:k1": "key1"}}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertDictEqual(
|
self.assertDictEqual(
|
||||||
res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
|
res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
|
||||||
)
|
)
|
||||||
|
|
||||||
res2 = self.get_success(
|
# Keys should be returned in the order they were uploaded. To test, advance time
|
||||||
|
# a little, then upload a second key with an earlier key ID; it should get
|
||||||
|
# returned second.
|
||||||
|
self.reactor.advance(1)
|
||||||
|
res = self.get_success(
|
||||||
|
self.handler.upload_keys_for_user(
|
||||||
|
local_user, device_id, {"one_time_keys": {"alg1:k0": "key0"}}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertDictEqual(
|
||||||
|
res, {"one_time_key_counts": {"alg1": 2, "signed_curve25519": 0}}
|
||||||
|
)
|
||||||
|
|
||||||
|
# now claim both keys back. They should be in the same order
|
||||||
|
res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{local_user: {device_id: {"alg1": 1}}},
|
{local_user: {device_id: {"alg1": 1}}},
|
||||||
self.requester,
|
self.requester,
|
||||||
|
@ -171,12 +183,27 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
res2,
|
res,
|
||||||
{
|
{
|
||||||
"failures": {},
|
"failures": {},
|
||||||
"one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
|
"one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
res = self.get_success(
|
||||||
|
self.handler.claim_one_time_keys(
|
||||||
|
{local_user: {device_id: {"alg1": 1}}},
|
||||||
|
self.requester,
|
||||||
|
timeout=None,
|
||||||
|
always_include_fallback_keys=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
res,
|
||||||
|
{
|
||||||
|
"failures": {},
|
||||||
|
"one_time_keys": {local_user: {device_id: {"alg1:k0": "key0"}}},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def test_claim_one_time_key_bulk(self) -> None:
|
def test_claim_one_time_key_bulk(self) -> None:
|
||||||
"""Like test_claim_one_time_key but claims multiple keys in one handler call."""
|
"""Like test_claim_one_time_key but claims multiple keys in one handler call."""
|
||||||
|
@ -336,6 +363,47 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
|
counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_claim_one_time_key_bulk_ordering(self) -> None:
|
||||||
|
"""Keys returned by the bulk claim call should be returned in the correct order"""
|
||||||
|
|
||||||
|
# Alice has lots of keys, uploaded in a specific order
|
||||||
|
alice = f"@alice:{self.hs.hostname}"
|
||||||
|
alice_dev = "alice_dev_1"
|
||||||
|
|
||||||
|
self.get_success(
|
||||||
|
self.handler.upload_keys_for_user(
|
||||||
|
alice,
|
||||||
|
alice_dev,
|
||||||
|
{"one_time_keys": {"alg1:k20": 20, "alg1:k21": 21, "alg1:k22": 22}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Advance time by 1s, to ensure that there is a difference in upload time.
|
||||||
|
self.reactor.advance(1)
|
||||||
|
self.get_success(
|
||||||
|
self.handler.upload_keys_for_user(
|
||||||
|
alice,
|
||||||
|
alice_dev,
|
||||||
|
{"one_time_keys": {"alg1:k10": 10, "alg1:k11": 11, "alg1:k12": 12}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now claim some, and check we get the right ones.
|
||||||
|
claim_res = self.get_success(
|
||||||
|
self.handler.claim_one_time_keys(
|
||||||
|
{alice: {alice_dev: {"alg1": 2}}},
|
||||||
|
self.requester,
|
||||||
|
timeout=None,
|
||||||
|
always_include_fallback_keys=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# We should get the first-uploaded keys, even though they have later key ids.
|
||||||
|
# We should get a random set of two of k20, k21, k22.
|
||||||
|
self.assertEqual(claim_res["failures"], {})
|
||||||
|
claimed_keys = claim_res["one_time_keys"]["@alice:test"]["alice_dev_1"]
|
||||||
|
self.assertEqual(len(claimed_keys), 2)
|
||||||
|
for key_id in claimed_keys.keys():
|
||||||
|
self.assertIn(key_id, ["alg1:k20", "alg1:k21", "alg1:k22"])
|
||||||
|
|
||||||
def test_fallback_key(self) -> None:
|
def test_fallback_key(self) -> None:
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
device_id = "xyz"
|
device_id = "xyz"
|
||||||
|
|
Loading…
Reference in a new issue