Mirror of https://github.com/element-hq/synapse.git
Add cache to `get_server_keys_json_for_remote` (#16123)
commit 0aba4a4eaa (parent 54a51ff6c1)
5 changed files with 144 additions and 101 deletions
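The commit replaces the old uncached `get_server_keys_json_for_remote`, which ran a database transaction per request and was keyed on `(server_name, key_id, from_server)` triplets, with two storage lookups: a per-`(server_name, key_id)` lookup backed by Synapse's `@cached`/`@cachedList` descriptors, and an uncached `get_all_server_keys_json_for_remote` for queries that name a server but no key IDs. A minimal sketch of the new call pattern, assuming a wired-up `store` (the server and key names here are placeholders):

from typing import Dict, Optional

from synapse.storage.keys import FetchKeyResultForRemote

async def example(store) -> None:
    # Specific key IDs requested: batched through the @cachedList descriptor,
    # one cache entry per (server_name, key_id) pair.
    results: Dict[str, Optional[FetchKeyResultForRemote]] = (
        await store.get_server_keys_json_for_remote("example.org", ["ed25519:key1"])
    )
    # No key IDs requested: fetch everything we have for the server (uncached).
    all_keys = await store.get_all_server_keys_json_for_remote("example.org")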
--- /dev/null
+++ b/changelog.d/16123.misc
@@ -0,0 +1 @@
+Add cache to `get_server_keys_json_for_remote`.
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -14,7 +14,7 @@
 import logging
 import re
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
 
 from signedjson.sign import sign_json
 
@@ -27,6 +27,7 @@ from synapse.http.servlet import (
     parse_integer,
     parse_json_object_from_request,
 )
+from synapse.storage.keys import FetchKeyResultForRemote
 from synapse.types import JsonDict
 from synapse.util import json_decoder
 from synapse.util.async_helpers import yieldable_gather_results
@@ -157,14 +158,22 @@ class RemoteKey(RestServlet):
     ) -> JsonDict:
         logger.info("Handling query for keys %r", query)
 
-        store_queries = []
+        server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
         for server_name, key_ids in query.items():
-            if not key_ids:
-                key_ids = (None,)
-            for key_id in key_ids:
-                store_queries.append((server_name, key_id, None))
-
-        cached = await self.store.get_server_keys_json_for_remote(store_queries)
+            if key_ids:
+                results: Mapping[
+                    str, Optional[FetchKeyResultForRemote]
+                ] = await self.store.get_server_keys_json_for_remote(
+                    server_name, key_ids
+                )
+            else:
+                results = await self.store.get_all_server_keys_json_for_remote(
+                    server_name
+                )
+
+            server_keys.update(
+                ((server_name, key_id), res) for key_id, res in results.items()
+            )
 
         json_results: Set[bytes] = set()
 
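Note the explicit `Mapping` annotation on the first branch: the keyed path returns `Dict[str, Optional[FetchKeyResultForRemote]]` while the all-keys path returns `Dict[str, FetchKeyResultForRemote]`, and `Mapping` (unlike `Dict`) is covariant in its value type, so one annotation covers both assignments. An illustrative reduction of why this type-checks:

from typing import Dict, Mapping, Optional

def demo(keyed: bool) -> None:
    results: Mapping[str, Optional[int]]
    if keyed:
        results = {"a": None}  # Dict[str, Optional[int]]: fine
    else:
        results = {"a": 1}  # Dict[str, int]: also fine, because Mapping's
                            # value type is covariant (Dict's is invariant)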
@@ -173,23 +182,20 @@ class RemoteKey(RestServlet):
         # Map server_name->key_id->int. Note that the value of the int is unused.
         # XXX: why don't we just use a set?
         cache_misses: Dict[str, Dict[str, int]] = {}
-        for (server_name, key_id, _), key_results in cached.items():
-            results = [(result["ts_added_ms"], result) for result in key_results]
-
-            if key_id is None:
+        for (server_name, key_id), key_result in server_keys.items():
+            if not query[server_name]:
                 # all keys were requested. Just return what we have without worrying
                 # about validity
-                for _, result in results:
-                    # Cast to bytes since postgresql returns a memoryview.
-                    json_results.add(bytes(result["key_json"]))
+                if key_result:
+                    json_results.add(key_result.key_json)
                 continue
 
             miss = False
-            if not results:
+            if key_result is None:
                 miss = True
             else:
-                ts_added_ms, most_recent_result = max(results)
-                ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
+                ts_added_ms = key_result.added_ts
+                ts_valid_until_ms = key_result.valid_until_ts
                 req_key = query.get(server_name, {}).get(key_id, {})
                 req_valid_until = req_key.get("minimum_valid_until_ts")
                 if req_valid_until is not None:
@@ -235,8 +241,8 @@ class RemoteKey(RestServlet):
                         ts_valid_until_ms,
                         time_now_ms,
                     )
-                    # Cast to bytes since postgresql returns a memoryview.
-                    json_results.add(bytes(most_recent_result["key_json"]))
+
+                    json_results.add(key_result.key_json)
 
             if miss and query_remote_on_cache_miss:
                 # only bother attempting to fetch keys from servers on our whitelist
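The dropped `bytes(...)` cast has not disappeared: the storage layer now normalizes the `memoryview` that the PostgreSQL driver returns for `bytea` columns when it builds `FetchKeyResultForRemote` (see the keys.py hunk below), so the servlet can add `key_result.key_json` to its `Set[bytes]` directly. Reduced to its essence (illustrative values):

raw = memoryview(b'{"server_name": "example.org"}')  # what postgres hands back
key_json = bytes(raw)  # normalized once, in the storage layer
assert isinstance(key_json, bytes)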
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,14 +16,13 @@
 import itertools
 import json
 import logging
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
+from typing import Dict, Iterable, Mapping, Optional, Tuple
 
 from signedjson.key import decode_verify_key_bytes
 from unpaddedbase64 import decode_base64
 
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction
-from synapse.storage.keys import FetchKeyResult
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
 from synapse.storage.types import Cursor
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.iterutils import batch_iter
@@ -34,7 +33,7 @@ logger = logging.getLogger(__name__)
 db_binary_type = memoryview
 
 
-class KeyStore(SQLBaseStore):
+class KeyStore(CacheInvalidationWorkerStore):
     """Persistence for signature verification keys"""
 
     @cached()
@@ -188,7 +187,12 @@ class KeyStore(SQLBaseStore):
             # invalidate takes a tuple corresponding to the params of
             # _get_server_keys_json. _get_server_keys_json only takes one
             # param, which is itself the 2-tuple (server_name, key_id).
-            self._get_server_keys_json.invalidate(((server_name, key_id),))
+            await self.invalidate_cache_and_stream(
+                "_get_server_keys_json", ((server_name, key_id),)
+            )
+            await self.invalidate_cache_and_stream(
+                "get_server_key_json_for_remote", (server_name, key_id)
+            )
 
     @cached()
     def _get_server_keys_json(
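`invalidate_cache_and_stream` (provided by `CacheInvalidationWorkerStore`, hence the new base class) drops the local in-memory cache entry and also writes to the cache-invalidation replication stream so that every other worker evicts its copy; a plain `.invalidate(...)` only affects the local process. A schematic sketch of the pattern, with made-up method names (`get_thing`, `store_thing`) and the surrounding Synapse wiring assumed:

from typing import Optional

from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.util.caches.descriptors import cached

class ExampleStore(CacheInvalidationWorkerStore):
    @cached()
    async def get_thing(self, owner: str, thing_id: str) -> Optional[bytes]:
        ...  # read from the database

    async def store_thing(self, owner: str, thing_id: str, value: bytes) -> None:
        ...  # write to the database, then evict the stale entry locally
        # *and* stream the invalidation so other workers evict it too,
        # keyed on the cached method's arguments:
        await self.invalidate_cache_and_stream("get_thing", (owner, thing_id))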
@@ -253,47 +257,87 @@ class KeyStore(SQLBaseStore):
 
         return await self.db_pool.runInteraction("get_server_keys_json", _txn)
 
+    @cached()
+    def get_server_key_json_for_remote(
+        self,
+        server_name: str,
+        key_id: str,
+    ) -> Optional[FetchKeyResultForRemote]:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="get_server_key_json_for_remote", list_name="key_ids"
+    )
     async def get_server_keys_json_for_remote(
-        self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
-    ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
-        """Retrieve the key json for a list of server_keys and key ids.
-        If no keys are found for a given server, key_id and source then
-        that server, key_id, and source triplet entry will be an empty list.
-        The JSON is returned as a byte array so that it can be efficiently
-        used in an HTTP response.
+        self, server_name: str, key_ids: Iterable[str]
+    ) -> Dict[str, Optional[FetchKeyResultForRemote]]:
+        """Fetch the cached keys for the given server/key IDs.
 
-        Args:
-            server_keys: List of (server_name, key_id, source) triplets.
-
-        Returns:
-            A mapping from (server_name, key_id, source) triplets to a list of dicts
+        If we have multiple entries for a given key ID, returns the most recent.
         """
-
-        def _get_server_keys_json_txn(
-            txn: LoggingTransaction,
-        ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
-            results = {}
-            for server_name, key_id, from_server in server_keys:
-                keyvalues = {"server_name": server_name}
-                if key_id is not None:
-                    keyvalues["key_id"] = key_id
-                if from_server is not None:
-                    keyvalues["from_server"] = from_server
-                rows = self.db_pool.simple_select_list_txn(
-                    txn,
-                    "server_keys_json",
-                    keyvalues=keyvalues,
-                    retcols=(
-                        "key_id",
-                        "from_server",
-                        "ts_added_ms",
-                        "ts_valid_until_ms",
-                        "key_json",
-                    ),
-                )
-                results[(server_name, key_id, from_server)] = rows
-            return results
-
-        return await self.db_pool.runInteraction(
-            "get_server_keys_json", _get_server_keys_json_txn
+        rows = await self.db_pool.simple_select_many_batch(
+            table="server_keys_json",
+            column="key_id",
+            iterable=key_ids,
+            keyvalues={"server_name": server_name},
+            retcols=(
+                "key_id",
+                "from_server",
+                "ts_added_ms",
+                "ts_valid_until_ms",
+                "key_json",
+            ),
+            desc="get_server_keys_json_for_remote",
         )
+
+        if not rows:
+            return {}
+
+        # We sort the rows so that the most recently added entry is picked up.
+        rows.sort(key=lambda r: r["ts_added_ms"])
+
+        return {
+            row["key_id"]: FetchKeyResultForRemote(
+                # Cast to bytes since postgresql returns a memoryview.
+                key_json=bytes(row["key_json"]),
+                valid_until_ts=row["ts_valid_until_ms"],
+                added_ts=row["ts_added_ms"],
+            )
+            for row in rows
+        }
+
+    async def get_all_server_keys_json_for_remote(
+        self,
+        server_name: str,
+    ) -> Dict[str, FetchKeyResultForRemote]:
+        """Fetch the cached keys for the given server.
+
+        If we have multiple entries for a given key ID, returns the most recent.
+        """
+        rows = await self.db_pool.simple_select_list(
+            table="server_keys_json",
+            keyvalues={"server_name": server_name},
+            retcols=(
+                "key_id",
+                "from_server",
+                "ts_added_ms",
+                "ts_valid_until_ms",
+                "key_json",
+            ),
+            desc="get_server_keys_json_for_remote",
+        )
+
+        if not rows:
+            return {}
+
+        rows.sort(key=lambda r: r["ts_added_ms"])
+
+        return {
+            row["key_id"]: FetchKeyResultForRemote(
+                # Cast to bytes since postgresql returns a memoryview.
+                key_json=bytes(row["key_json"]),
+                valid_until_ts=row["ts_valid_until_ms"],
+                added_ts=row["ts_added_ms"],
+            )
+            for row in rows
+        }
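Two things are worth noting in this hunk. First, the `@cached` method that raises `NotImplementedError` is the standard Synapse `cachedList` idiom: it exists only to declare the per-`(server_name, key_id)` cache, which `@cachedList` consults for hits, invoking the batched method just for the misses and back-filling the cache from its result; the singular method is never called directly. Second, "returns the most recent" falls out of the ascending sort plus dict comprehension, since later (newer) rows overwrite earlier ones under the same key. A runnable illustration of that trick, with invented sample rows:

rows = [
    {"key_id": "ed25519:a", "ts_added_ms": 100, "key_json": b"old"},
    {"key_id": "ed25519:a", "ts_added_ms": 200, "key_json": b"new"},
    {"key_id": "ed25519:b", "ts_added_ms": 150, "key_json": b"only"},
]
rows.sort(key=lambda r: r["ts_added_ms"])  # oldest first
latest = {r["key_id"]: r["key_json"] for r in rows}  # newest wins per key
assert latest == {"ed25519:a": b"new", "ed25519:b": b"only"}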
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -25,3 +25,10 @@ logger = logging.getLogger(__name__)
 class FetchKeyResult:
     verify_key: VerifyKey  # the key itself
     valid_until_ts: int  # how long we can use this key for
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FetchKeyResultForRemote:
+    key_json: bytes  # the full key JSON
+    valid_until_ts: int  # how long we can use this key for, in milliseconds.
+    added_ts: int  # When we added this key, in milliseconds.
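`@attr.s(slots=True, frozen=True, auto_attribs=True)` turns the annotated attributes into a slotted, immutable value class, which makes instances safe to hand out from a shared cache. For example:

import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class FetchKeyResultForRemote:
    key_json: bytes
    valid_until_ts: int
    added_ts: int

res = FetchKeyResultForRemote(key_json=b"{}", valid_until_ts=2000, added_ts=1000)
# res.added_ts = 0  # would raise attr.exceptions.FrozenInstanceError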
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -456,24 +456,19 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
         self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
-        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json_for_remote(
-                [lookup_triplet]
+                SERVER_NAME, [testverifykey_id]
             )
         )
-        res_keys = key_json[lookup_triplet]
-        self.assertEqual(len(res_keys), 1)
-        res = res_keys[0]
-        self.assertEqual(res["key_id"], testverifykey_id)
-        self.assertEqual(res["from_server"], SERVER_NAME)
-        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
-        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+        res = key_json[testverifykey_id]
+        self.assertIsNotNone(res)
+        assert res is not None
+        self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+        self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
 
         # we expect it to be encoded as canonical json *before* it hits the db
-        self.assertEqual(
-            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
-        )
+        self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
 
         # change the server name: the result should be ignored
         response["server_name"] = "OTHER_SERVER"
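The seemingly redundant `assert res is not None` after `self.assertIsNotNone(res)` is for the type checker: unittest assertions do not narrow types, so without the bare `assert`, accessing `res.added_ts` on an `Optional[FetchKeyResultForRemote]` would fail mypy. Reduced to its essence:

from typing import Optional

def check(res: Optional[int]) -> int:
    assert res is not None  # narrows Optional[int] to int for mypy
    return res + 1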
@@ -576,23 +571,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
-        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json_for_remote(
-                [lookup_triplet]
+                SERVER_NAME, [testverifykey_id]
             )
         )
-        res_keys = key_json[lookup_triplet]
-        self.assertEqual(len(res_keys), 1)
-        res = res_keys[0]
-        self.assertEqual(res["key_id"], testverifykey_id)
-        self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
-        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
-        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+        res = key_json[testverifykey_id]
+        self.assertIsNotNone(res)
+        assert res is not None
+        self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+        self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
 
-        self.assertEqual(
-            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
-        )
+        self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
 
     def test_get_multiple_keys_from_perspectives(self) -> None:
         """Check that we can correctly request multiple keys for the same server"""
@@ -699,23 +689,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
-        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json_for_remote(
-                [lookup_triplet]
+                SERVER_NAME, [testverifykey_id]
             )
         )
-        res_keys = key_json[lookup_triplet]
-        self.assertEqual(len(res_keys), 1)
-        res = res_keys[0]
-        self.assertEqual(res["key_id"], testverifykey_id)
-        self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
-        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
-        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+        res = key_json[testverifykey_id]
+        self.assertIsNotNone(res)
+        assert res is not None
+        self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+        self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
 
-        self.assertEqual(
-            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
-        )
+        self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
 
     def test_invalid_perspectives_responses(self) -> None:
         """Check that invalid responses from the perspectives server are rejected"""