Add cache to get_server_keys_json_for_remote (#16123)

This commit is contained in:
Erik Johnston 2023-08-18 11:05:01 +01:00 committed by GitHub
parent 54a51ff6c1
commit 0aba4a4eaa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 144 additions and 101 deletions

1
changelog.d/16123.misc Normal file
View file

@ -0,0 +1 @@
Add cache to `get_server_keys_json_for_remote`.

View file

@ -14,7 +14,7 @@
import logging import logging
import re import re
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
from signedjson.sign import sign_json from signedjson.sign import sign_json
@ -27,6 +27,7 @@ from synapse.http.servlet import (
parse_integer, parse_integer,
parse_json_object_from_request, parse_json_object_from_request,
) )
from synapse.storage.keys import FetchKeyResultForRemote
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_decoder from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results from synapse.util.async_helpers import yieldable_gather_results
@ -157,14 +158,22 @@ class RemoteKey(RestServlet):
) -> JsonDict: ) -> JsonDict:
logger.info("Handling query for keys %r", query) logger.info("Handling query for keys %r", query)
store_queries = [] server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
for server_name, key_ids in query.items(): for server_name, key_ids in query.items():
if not key_ids: if key_ids:
key_ids = (None,) results: Mapping[
for key_id in key_ids: str, Optional[FetchKeyResultForRemote]
store_queries.append((server_name, key_id, None)) ] = await self.store.get_server_keys_json_for_remote(
server_name, key_ids
)
else:
results = await self.store.get_all_server_keys_json_for_remote(
server_name
)
cached = await self.store.get_server_keys_json_for_remote(store_queries) server_keys.update(
((server_name, key_id), res) for key_id, res in results.items()
)
json_results: Set[bytes] = set() json_results: Set[bytes] = set()
@ -173,23 +182,20 @@ class RemoteKey(RestServlet):
# Map server_name->key_id->int. Note that the value of the int is unused. # Map server_name->key_id->int. Note that the value of the int is unused.
# XXX: why don't we just use a set? # XXX: why don't we just use a set?
cache_misses: Dict[str, Dict[str, int]] = {} cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), key_results in cached.items(): for (server_name, key_id), key_result in server_keys.items():
results = [(result["ts_added_ms"], result) for result in key_results] if not query[server_name]:
if key_id is None:
# all keys were requested. Just return what we have without worrying # all keys were requested. Just return what we have without worrying
# about validity # about validity
for _, result in results: if key_result:
# Cast to bytes since postgresql returns a memoryview. json_results.add(key_result.key_json)
json_results.add(bytes(result["key_json"]))
continue continue
miss = False miss = False
if not results: if key_result is None:
miss = True miss = True
else: else:
ts_added_ms, most_recent_result = max(results) ts_added_ms = key_result.added_ts
ts_valid_until_ms = most_recent_result["ts_valid_until_ms"] ts_valid_until_ms = key_result.valid_until_ts
req_key = query.get(server_name, {}).get(key_id, {}) req_key = query.get(server_name, {}).get(key_id, {})
req_valid_until = req_key.get("minimum_valid_until_ts") req_valid_until = req_key.get("minimum_valid_until_ts")
if req_valid_until is not None: if req_valid_until is not None:
@ -235,8 +241,8 @@ class RemoteKey(RestServlet):
ts_valid_until_ms, ts_valid_until_ms,
time_now_ms, time_now_ms,
) )
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"])) json_results.add(key_result.key_json)
if miss and query_remote_on_cache_miss: if miss and query_remote_on_cache_miss:
# only bother attempting to fetch keys from servers on our whitelist # only bother attempting to fetch keys from servers on our whitelist

View file

@ -16,14 +16,13 @@
import itertools import itertools
import json import json
import logging import logging
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple from typing import Dict, Iterable, Mapping, Optional, Tuple
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
from synapse.storage._base import SQLBaseStore from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.database import LoggingTransaction from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
from synapse.storage.keys import FetchKeyResult
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
@ -34,7 +33,7 @@ logger = logging.getLogger(__name__)
db_binary_type = memoryview db_binary_type = memoryview
class KeyStore(SQLBaseStore): class KeyStore(CacheInvalidationWorkerStore):
"""Persistence for signature verification keys""" """Persistence for signature verification keys"""
@cached() @cached()
@ -188,7 +187,12 @@ class KeyStore(SQLBaseStore):
# invalidate takes a tuple corresponding to the params of # invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one # _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id). # param, which is itself the 2-tuple (server_name, key_id).
self._get_server_keys_json.invalidate(((server_name, key_id),)) await self.invalidate_cache_and_stream(
"_get_server_keys_json", ((server_name, key_id),)
)
await self.invalidate_cache_and_stream(
"get_server_key_json_for_remote", (server_name, key_id)
)
@cached() @cached()
def _get_server_keys_json( def _get_server_keys_json(
@ -253,47 +257,87 @@ class KeyStore(SQLBaseStore):
return await self.db_pool.runInteraction("get_server_keys_json", _txn) return await self.db_pool.runInteraction("get_server_keys_json", _txn)
@cached()
def get_server_key_json_for_remote(
self,
server_name: str,
key_id: str,
) -> Optional[FetchKeyResultForRemote]:
raise NotImplementedError()
@cachedList(
cached_method_name="get_server_key_json_for_remote", list_name="key_ids"
)
async def get_server_keys_json_for_remote( async def get_server_keys_json_for_remote(
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] self, server_name: str, key_ids: Iterable[str]
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: ) -> Dict[str, Optional[FetchKeyResultForRemote]]:
"""Retrieve the key json for a list of server_keys and key ids. """Fetch the cached keys for the given server/key IDs.
If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list.
The JSON is returned as a byte array so that it can be efficiently
used in an HTTP response.
Args: If we have multiple entries for a given key ID, returns the most recent.
server_keys: List of (server_name, key_id, source) triplets.
Returns:
A mapping from (server_name, key_id, source) triplets to a list of dicts
""" """
rows = await self.db_pool.simple_select_many_batch(
def _get_server_keys_json_txn( table="server_keys_json",
txn: LoggingTransaction, column="key_id",
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: iterable=key_ids,
results = {} keyvalues={"server_name": server_name},
for server_name, key_id, from_server in server_keys: retcols=(
keyvalues = {"server_name": server_name} "key_id",
if key_id is not None: "from_server",
keyvalues["key_id"] = key_id "ts_added_ms",
if from_server is not None: "ts_valid_until_ms",
keyvalues["from_server"] = from_server "key_json",
rows = self.db_pool.simple_select_list_txn( ),
txn, desc="get_server_keys_json_for_remote",
"server_keys_json",
keyvalues=keyvalues,
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
)
results[(server_name, key_id, from_server)] = rows
return results
return await self.db_pool.runInteraction(
"get_server_keys_json", _get_server_keys_json_txn
) )
if not rows:
return {}
# We sort the rows so that the most recently added entry is picked up.
rows.sort(key=lambda r: r["ts_added_ms"])
return {
row["key_id"]: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
key_json=bytes(row["key_json"]),
valid_until_ts=row["ts_valid_until_ms"],
added_ts=row["ts_added_ms"],
)
for row in rows
}
async def get_all_server_keys_json_for_remote(
self,
server_name: str,
) -> Dict[str, FetchKeyResultForRemote]:
"""Fetch the cached keys for the given server.
If we have multiple entries for a given key ID, returns the most recent.
"""
rows = await self.db_pool.simple_select_list(
table="server_keys_json",
keyvalues={"server_name": server_name},
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
desc="get_server_keys_json_for_remote",
)
if not rows:
return {}
rows.sort(key=lambda r: r["ts_added_ms"])
return {
row["key_id"]: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
key_json=bytes(row["key_json"]),
valid_until_ts=row["ts_valid_until_ms"],
added_ts=row["ts_added_ms"],
)
for row in rows
}

View file

@ -25,3 +25,10 @@ logger = logging.getLogger(__name__)
class FetchKeyResult: class FetchKeyResult:
verify_key: VerifyKey # the key itself verify_key: VerifyKey # the key itself
valid_until_ts: int # how long we can use this key for valid_until_ts: int # how long we can use this key for
@attr.s(slots=True, frozen=True, auto_attribs=True)
class FetchKeyResultForRemote:
key_json: bytes # the full key JSON
valid_until_ts: int # how long we can use this key for, in milliseconds.
added_ts: int # When we added this key, in milliseconds.

View file

@ -456,24 +456,19 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1") self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated # check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success( key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote( self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet] SERVER_NAME, [testverifykey_id]
) )
) )
res_keys = key_json[lookup_triplet] res = key_json[testverifykey_id]
self.assertEqual(len(res_keys), 1) self.assertIsNotNone(res)
res = res_keys[0] assert res is not None
self.assertEqual(res["key_id"], testverifykey_id) self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
self.assertEqual(res["from_server"], SERVER_NAME) self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
# we expect it to be encoded as canonical json *before* it hits the db # we expect it to be encoded as canonical json *before* it hits the db
self.assertEqual( self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
# change the server name: the result should be ignored # change the server name: the result should be ignored
response["server_name"] = "OTHER_SERVER" response["server_name"] = "OTHER_SERVER"
@ -576,23 +571,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1") self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated # check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success( key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote( self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet] SERVER_NAME, [testverifykey_id]
) )
) )
res_keys = key_json[lookup_triplet] res = key_json[testverifykey_id]
self.assertEqual(len(res_keys), 1) self.assertIsNotNone(res)
res = res_keys[0] assert res is not None
self.assertEqual(res["key_id"], testverifykey_id) self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
self.assertEqual( self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
def test_get_multiple_keys_from_perspectives(self) -> None: def test_get_multiple_keys_from_perspectives(self) -> None:
"""Check that we can correctly request multiple keys for the same server""" """Check that we can correctly request multiple keys for the same server"""
@ -699,23 +689,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1") self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated # check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success( key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote( self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet] SERVER_NAME, [testverifykey_id]
) )
) )
res_keys = key_json[lookup_triplet] res = key_json[testverifykey_id]
self.assertEqual(len(res_keys), 1) self.assertIsNotNone(res)
res = res_keys[0] assert res is not None
self.assertEqual(res["key_id"], testverifykey_id) self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
self.assertEqual( self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
def test_invalid_perspectives_responses(self) -> None: def test_invalid_perspectives_responses(self) -> None:
"""Check that invalid responses from the perspectives server are rejected""" """Check that invalid responses from the perspectives server are rejected"""