Claim fallback keys in bulk (#16570)

This commit is contained in:
David Robertson 2023-10-30 14:34:37 +00:00 committed by GitHub
parent a3f6200d65
commit fdce83ee60
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 162 additions and 0 deletions

View file

@ -0,0 +1 @@
Improve the performance of claiming encryption keys.

View file

@ -659,6 +659,20 @@ class E2eKeysHandler:
timeout: Optional[int],
always_include_fallback_keys: bool,
) -> JsonDict:
"""
Args:
query: A chain of maps from (user_id, device_id, algorithm) to the requested
number of keys to claim.
user: The user who is claiming these keys.
timeout: How long to wait for any federation key claim requests before
giving up.
always_include_fallback_keys: always include a fallback key for local users'
devices, even if we managed to claim a one-time-key.
Returns: a heterogeneous dict with two keys:
one_time_keys: chain of maps user ID -> device ID -> key ID -> key.
failures: map from remote destination to a JsonDict describing the error.
"""
local_query: List[Tuple[str, str, str, int]] = []
remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}

View file

@ -420,6 +420,16 @@ class LoggingTransaction:
self._do_execute(self.txn.execute, sql, parameters)
def executemany(self, sql: str, *args: Any) -> None:
"""Repeatedly execute the same piece of SQL with different parameters.
See https://peps.python.org/pep-0249/#executemany. Note in particular that
> Use of this method for an operation which produces one or more result sets
> constitutes undefined behavior
so you can't use this for e.g. a SELECT, an UPDATE ... RETURNING, or a
DELETE FROM... RETURNING.
"""
# TODO: we should add a type for *args here. Looking at Cursor.executemany
# and DBAPI2 it ought to be Sequence[_Parameter], but we pass in
# Iterable[Iterable[Any]] in execute_batch and execute_values above, which mypy

View file

@ -24,6 +24,7 @@ from typing import (
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
@ -1260,6 +1261,65 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
"""
if isinstance(self.database_engine, PostgresEngine):
return await self.db_pool.runInteraction(
"_claim_e2e_fallback_keys_bulk",
self._claim_e2e_fallback_keys_bulk_txn,
query_list,
db_autocommit=True,
)
# Use an UPDATE FROM... RETURNING combined with a VALUES block to do
# everything in one query. Note: this is also supported in SQLite 3.33.0,
# (see https://www.sqlite.org/lang_update.html#update_from), but we do not
# have an equivalent of psycopg2's execute_values to do this in one query.
else:
return await self._claim_e2e_fallback_keys_simple(query_list)
def _claim_e2e_fallback_keys_bulk_txn(
self,
txn: LoggingTransaction,
query_list: Iterable[Tuple[str, str, str, bool]],
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
"""Efficient implementation of claim_e2e_fallback_keys for Postgres.
Safe to autocommit: this is a single query.
"""
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
sql = """
WITH claims(user_id, device_id, algorithm, mark_as_used) AS (
VALUES ?
)
UPDATE e2e_fallback_keys_json k
SET used = used OR mark_as_used
FROM claims
WHERE (k.user_id, k.device_id, k.algorithm) = (claims.user_id, claims.device_id, claims.algorithm)
RETURNING k.user_id, k.device_id, k.algorithm, k.key_id, k.key_json;
"""
claimed_keys = cast(
List[Tuple[str, str, str, str, str]],
txn.execute_values(sql, query_list),
)
seen_user_device: Set[Tuple[str, str]] = set()
for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
if (user_id, device_id) in seen_user_device:
continue
seen_user_device.add((user_id, device_id))
self._invalidate_cache_and_stream(
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
)
return results
async def _claim_e2e_fallback_keys_simple(
self,
query_list: Iterable[Tuple[str, str, str, bool]],
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
"""Naive, inefficient implementation of claim_e2e_fallback_keys for SQLite."""
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, device_id, algorithm, mark_as_used in query_list:
row = await self.db_pool.simple_select_one(

View file

@ -322,6 +322,83 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
)
def test_fallback_key_bulk(self) -> None:
"""Like test_fallback_key, but claims multiple keys in one handler call."""
alice = f"@alice:{self.hs.hostname}"
brian = f"@brian:{self.hs.hostname}"
chris = f"@chris:{self.hs.hostname}"
# Have three users upload fallback keys for two devices.
fallback_keys = {
alice: {
"alice_dev_1": {"alg1:k1": "fallback_key1"},
"alice_dev_2": {"alg2:k2": "fallback_key2"},
},
brian: {
"brian_dev_1": {"alg1:k3": "fallback_key3"},
"brian_dev_2": {"alg2:k4": "fallback_key4"},
},
chris: {
"chris_dev_1": {"alg1:k5": "fallback_key5"},
"chris_dev_2": {"alg2:k6": "fallback_key6"},
},
}
for user_id, devices in fallback_keys.items():
for device_id, key_dict in devices.items():
self.get_success(
self.handler.upload_keys_for_user(
user_id,
device_id,
{"fallback_keys": key_dict},
)
)
# Each device should have an unused fallback key.
for user_id, devices in fallback_keys.items():
for device_id in devices:
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
expected_algorithm_name = f"alg{device_id[-1]}"
self.assertEqual(fallback_res, [expected_algorithm_name])
# Claim the fallback key for one device per user.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{
alice: {"alice_dev_1": {"alg1": 1}},
brian: {"brian_dev_2": {"alg2": 1}},
chris: {"chris_dev_2": {"alg2": 1}},
},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
)
expected_claims = {
alice: {"alice_dev_1": {"alg1:k1": "fallback_key1"}},
brian: {"brian_dev_2": {"alg2:k4": "fallback_key4"}},
chris: {"chris_dev_2": {"alg2:k6": "fallback_key6"}},
}
self.assertEqual(
claim_res,
{"failures": {}, "one_time_keys": expected_claims},
)
for user_id, devices in fallback_keys.items():
for device_id in devices:
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
# Claimed fallback keys should no longer show up as unused.
# Unclaimed fallback keys should still be unused.
if device_id in expected_claims[user_id]:
self.assertEqual(fallback_res, [])
else:
expected_algorithm_name = f"alg{device_id[-1]}"
self.assertEqual(fallback_res, [expected_algorithm_name])
def test_fallback_key_always_returned(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"