Update the device list cache when keys/query is called

We only update the cache when it makes sense to do so: when the requesting
user shares a room with the queried user and the query does not ask for
specific devices.
Author: Jorik Schellekens  2019-07-16 17:01:18 +01:00
parent d86321300a
commit 700f1ca681
2 changed files with 231 additions and 77 deletions
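
The gist of the change: when a keys/query request reaches E2eKeysHandler.query_devices, the handler now resyncs and caches a remote user's full device list before falling back to a one-off federation key query. A minimal sketch of the decision rule, distilled from the comments in the diff below (illustrative only; should_resync, requested_device_ids and shared_room_ids are stand-in names, not identifiers from the commit):

    # Sketch of the caching rule this commit adds, not code from the diff
    # itself.  We only resync (and therefore cache) a remote user's device
    # list when no specific devices were asked for and we share at least
    # one room with them -- otherwise tracking the list buys us nothing.
    def should_resync(requested_device_ids, shared_room_ids):
        return not requested_device_ids and bool(shared_room_ids)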

synapse/handlers/device.py

@@ -18,6 +18,7 @@ from six import iteritems, itervalues
 from twisted.internet import defer
 
+import synapse.logging.opentracing as opentracing
 from synapse.api import errors
 from synapse.api.constants import EventTypes
 from synapse.api.errors import (
@@ -211,12 +212,12 @@ class DeviceHandler(DeviceWorkerHandler):
         self.federation_sender = hs.get_federation_sender()
 
-        self._edu_updater = DeviceListEduUpdater(hs, self)
+        self.device_list_updater = DeviceListUpdater(hs, self)
 
         federation_registry = hs.get_federation_registry()
 
         federation_registry.register_edu_handler(
-            "m.device_list_update", self._edu_updater.incoming_device_list_update
+            "m.device_list_update", self.device_list_updater.incoming_device_list_update
         )
         federation_registry.register_query_handler(
             "user_devices", self.on_federation_query_user_devices
@@ -430,7 +431,7 @@ def _update_device_from_client_ips(device, client_ips):
     device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
 
 
-class DeviceListEduUpdater(object):
+class DeviceListUpdater(object):
     "Handles incoming device list updates from federation and updates the DB"
 
     def __init__(self, hs, device_handler):
@@ -523,75 +524,7 @@ class DeviceListEduUpdater(object):
             logger.debug("Need to re-sync devices for %r? %r", user_id, resync)
 
             if resync:
-                # Fetch all devices for the user.
-                origin = get_domain_from_id(user_id)
-                try:
-                    result = yield self.federation.query_user_devices(origin, user_id)
-                except (
-                    NotRetryingDestination,
-                    RequestSendFailed,
-                    HttpResponseException,
-                ):
-                    # TODO: Remember that we are now out of sync and try again
-                    # later
-                    logger.warn("Failed to handle device list update for %s", user_id)
-                    # We abort on exceptions rather than accepting the update
-                    # as otherwise synapse will 'forget' that its device list
-                    # is out of date. If we bail then we will retry the resync
-                    # next time we get a device list update for this user_id.
-                    # This makes it more likely that the device lists will
-                    # eventually become consistent.
-                    return
-                except FederationDeniedError as e:
-                    logger.info(e)
-                    return
-                except Exception:
-                    # TODO: Remember that we are now out of sync and try again
-                    # later
-                    logger.exception(
-                        "Failed to handle device list update for %s", user_id
-                    )
-                    return
-
-                stream_id = result["stream_id"]
-                devices = result["devices"]
-
-                # If the remote server has more than ~1000 devices for this user
-                # we assume that something is going horribly wrong (e.g. a bot
-                # that logs in and creates a new device every time it tries to
-                # send a message). Maintaining lots of devices per user in the
-                # cache can cause serious performance issues as if this request
-                # takes more than 60s to complete, internal replication from the
-                # inbound federation worker to the synapse master may time out
-                # causing the inbound federation to fail and causing the remote
-                # server to retry, causing a DoS. So in this scenario we give
-                # up on storing the total list of devices and only handle the
-                # delta instead.
-                if len(devices) > 1000:
-                    logger.warn(
-                        "Ignoring device list snapshot for %s as it has >1K devs (%d)",
-                        user_id,
-                        len(devices),
-                    )
-                    devices = []
-
-                for device in devices:
-                    logger.debug(
-                        "Handling resync update %r/%r, ID: %r",
-                        user_id,
-                        device["device_id"],
-                        stream_id,
-                    )
-
-                yield self.store.update_remote_device_list_cache(
-                    user_id, devices, stream_id
-                )
-                device_ids = [device["device_id"] for device in devices]
-                yield self.device_handler.notify_device_update(user_id, device_ids)
-
-                # We clobber the seen updates since we've re-synced from a given
-                # point.
-                self._seen_updates[user_id] = set([stream_id])
+                yield self.user_device_resync(user_id)
             else:
                 # Simply update the single device, since we know that is the only
                 # change (because of the single prev_id matching the current cache)
@@ -638,3 +571,86 @@ class DeviceListEduUpdater(object):
                 stream_id_in_updates.add(stream_id)
 
         defer.returnValue(False)
+
+    @opentracing.trace_deferred
+    @defer.inlineCallbacks
+    def user_device_resync(self, user_id):
+        """Fetches all devices for a user and updates the device cache with them.
+
+        Args:
+            user_id (str): The ID of the user whose device list will be updated.
+
+        Returns:
+            Deferred[dict]: a dict with device info, as under "devices" in the
+            result of this request:
+            https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+        """
+        opentracing.log_kv({"message": "Doing resync to update device list."})
+        # Fetch all devices for the user.
+        origin = get_domain_from_id(user_id)
+        try:
+            result = yield self.federation.query_user_devices(origin, user_id)
+        except (NotRetryingDestination, RequestSendFailed, HttpResponseException):
+            # TODO: Remember that we are now out of sync and try again
+            # later
+            logger.warn("Failed to handle device list update for %s", user_id)
+            # We abort on exceptions rather than accepting the update
+            # as otherwise synapse will 'forget' that its device list
+            # is out of date. If we bail then we will retry the resync
+            # next time we get a device list update for this user_id.
+            # This makes it more likely that the device lists will
+            # eventually become consistent.
+            return
+        except FederationDeniedError as e:
+            opentracing.set_tag("error", True)
+            opentracing.log_kv({"reason": "FederationDeniedError"})
+            logger.info(e)
+            return
+        except Exception as e:
+            # TODO: Remember that we are now out of sync and try again
+            # later
+            opentracing.set_tag("error", True)
+            opentracing.log_kv(
+                {"message": "Exception raised by federation request", "exception": e}
+            )
+            logger.exception("Failed to handle device list update for %s", user_id)
+            return
+        opentracing.log_kv({"result": result})
+        stream_id = result["stream_id"]
+        devices = result["devices"]
+
+        # If the remote server has more than ~1000 devices for this user
+        # we assume that something is going horribly wrong (e.g. a bot
+        # that logs in and creates a new device every time it tries to
+        # send a message). Maintaining lots of devices per user in the
+        # cache can cause serious performance issues as if this request
+        # takes more than 60s to complete, internal replication from the
+        # inbound federation worker to the synapse master may time out
+        # causing the inbound federation to fail and causing the remote
+        # server to retry, causing a DoS. So in this scenario we give
+        # up on storing the total list of devices and only handle the
+        # delta instead.
+        if len(devices) > 1000:
+            logger.warn(
+                "Ignoring device list snapshot for %s as it has >1K devs (%d)",
+                user_id,
+                len(devices),
+            )
+            devices = []
+
+        for device in devices:
+            logger.debug(
+                "Handling resync update %r/%r, ID: %r",
+                user_id,
+                device["device_id"],
+                stream_id,
+            )
+
+        yield self.store.update_remote_device_list_cache(user_id, devices, stream_id)
+        device_ids = [device["device_id"] for device in devices]
+        yield self.device_handler.notify_device_update(user_id, device_ids)
+
+        # We clobber the seen updates since we've re-synced from a given
+        # point.
+        self._seen_updates[user_id] = set([stream_id])
+
+        defer.returnValue(result)
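
The body of user_device_resync is lifted almost verbatim from the old EDU-handler branch; what changed is that it is now reusable and resolves to the federation response. A caller outside the EDU path might use it like this (a hypothetical helper for illustration, not part of the diff):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def refresh_remote_user(device_list_updater, user_id):
        # Re-fetch the user's devices over federation and update the cache.
        # The resolved value mirrors the federation response: a dict with
        # "stream_id" and "devices" keys.
        result = yield device_list_updater.user_device_resync(user_id)
        defer.returnValue([d["device_id"] for d in result["devices"]])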

synapse/handlers/e2e_keys.py

@@ -22,8 +22,14 @@ from canonicaljson import encode_canonical_json, json
 from twisted.internet import defer
 
+import synapse.logging.opentracing as opentracing
 from synapse.api.errors import CodeMessageException, SynapseError
-from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.context import (
+    LoggingContext,
+    PreserveLoggingContext,
+    make_deferred_yieldable,
+    run_in_background,
+)
 from synapse.types import UserID, get_domain_from_id
 from synapse.util.retryutils import NotRetryingDestination
@@ -45,6 +51,7 @@ class E2eKeysHandler(object):
             "client_keys", self.on_federation_query_client_keys
         )
 
+    @opentracing.trace_deferred
     @defer.inlineCallbacks
     def query_devices(self, query_body, timeout):
         """ Handle a device key query from a client
@@ -65,6 +72,7 @@ class E2eKeysHandler(object):
                 }
             }
         """
+
         device_keys_query = query_body.get("device_keys", {})
 
         # separate users by domain.
@ -79,6 +87,9 @@ class E2eKeysHandler(object):
else: else:
remote_queries[user_id] = device_ids remote_queries[user_id] = device_ids
opentracing.set_tag("local_key_query", local_query)
opentracing.set_tag("remote_key_query", remote_queries)
# First get local devices. # First get local devices.
failures = {} failures = {}
results = {} results = {}
@@ -119,9 +130,81 @@ class E2eKeysHandler(object):
                 r[user_id] = remote_queries[user_id]
 
         # Now fetch any devices that we don't have in our cache
+        @opentracing.trace_deferred
         @defer.inlineCallbacks
         def do_remote_query(destination):
             destination_query = remote_queries_not_in_cache[destination]
+
+            opentracing.set_tag("key_query", destination_query)
+
+            # We first consider whether we wish to update the device list
+            # cache with the user's device list. We want to track a user's
+            # devices when the authenticated user shares a room with the
+            # queried user and the query has not specified a particular
+            # device.
+            # If we update the cache for the queried user we remove them from
+            # further queries. We use the more efficient batched
+            # query_client_keys for all remaining users.
+            user_ids_updated = []
+            for (user_id, device_list) in destination_query.items():
+                if user_id not in user_ids_updated:
+                    try:
+                        with PreserveLoggingContext(LoggingContext.current_context()):
+                            room_ids = yield self.store.get_rooms_for_user(user_id)
+                            if not device_list and room_ids:
+                                opentracing.log_kv(
+                                    {
+                                        "message": "Resyncing devices for user",
+                                        "user_id": user_id,
+                                    }
+                                )
+                                user_devices = yield self.device_handler.device_list_updater.user_device_resync(
+                                    user_id
+                                )
+                                user_devices = user_devices["devices"]
+                                opentracing.log_kv(
+                                    {
+                                        "message": "got user devices",
+                                        "user_devices": user_devices,
+                                    }
+                                )
+                                # Collect every device for this user, keyed by
+                                # device ID.
+                                results[user_id] = {
+                                    device["device_id"]: device["keys"]
+                                    for device in user_devices
+                                }
+                                opentracing.log_kv(
+                                    {"adding user to user_ids_updated": user_id}
+                                )
+                                user_ids_updated.append(user_id)
+                            else:
+                                opentracing.log_kv(
+                                    {
+                                        "message": "Not resyncing devices for user",
+                                        "user_id": user_id,
+                                    }
+                                )
+                    except Exception as e:
+                        # Append this failure to any existing failures for the
+                        # destination.
+                        failures.setdefault(destination, []).append(
+                            _exception_to_failure(e)
+                        )
+                        opentracing.set_tag("error", True)
+                        opentracing.log_kv({"exception": e})
+
+            if len(destination_query) == len(user_ids_updated):
+                # We've updated all the users in the query and we do not need to
+                # make any further remote calls.
+                return
+
+            # Remove all the users from the query which we have updated
+            for user_id in user_ids_updated:
+                destination_query.pop(user_id)
+
+            opentracing.log_kv(
+                {
+                    "message": "Querying remote servers for keys",
+                    "destination_query": destination_query,
+                    "not querying": user_ids_updated,
+                }
+            )
+
             try:
                 remote_result = yield self.federation.query_client_keys(
                     destination, {"device_keys": destination_query}, timeout=timeout
@@ -132,7 +215,8 @@ class E2eKeysHandler(object):
                         results[user_id] = keys
 
             except Exception as e:
-                failures[destination] = _exception_to_failure(e)
+                failure = _exception_to_failure(e)
+                failures[destination] = failure
 
         yield make_deferred_yieldable(
             defer.gatherResults(
@@ -143,9 +227,10 @@ class E2eKeysHandler(object):
                 consumeErrors=True,
             )
         )
 
+        opentracing.log_kv({"device_keys": results, "failures": failures})
         defer.returnValue({"device_keys": results, "failures": failures})
 
+    @opentracing.trace_deferred
     @defer.inlineCallbacks
     def query_local_devices(self, query):
         """Get E2E device keys for local users
@@ -158,6 +243,7 @@ class E2eKeysHandler(object):
             defer.Deferred: (resolves to dict[string, dict[string, dict]]):
                 map from user_id -> device_id -> device details
         """
+        opentracing.set_tag("local_query", query)
         local_query = []
 
         result_dict = {}
@@ -165,6 +251,14 @@ class E2eKeysHandler(object):
             # we use UserID.from_string to catch invalid user ids
             if not self.is_mine(UserID.from_string(user_id)):
                 logger.warning("Request for keys for non-local user %s", user_id)
+                opentracing.log_kv(
+                    {
+                        "message": "Requested a local key for a user which"
+                        + " was not local to the homeserver",
+                        "user_id": user_id,
+                    }
+                )
+                opentracing.set_tag("error", True)
                 raise SynapseError(400, "Not a user here")
 
             if not device_ids:
@@ -189,6 +283,7 @@ class E2eKeysHandler(object):
                     r["unsigned"]["device_display_name"] = display_name
                 result_dict[user_id][device_id] = r
 
+        opentracing.log_kv(results)
         defer.returnValue(result_dict)
 
     @defer.inlineCallbacks
@@ -199,6 +294,7 @@ class E2eKeysHandler(object):
         res = yield self.query_local_devices(device_keys_query)
         defer.returnValue({"device_keys": res})
 
+    @opentracing.trace_deferred
     @defer.inlineCallbacks
     def claim_one_time_keys(self, query, timeout):
         local_query = []
@@ -213,6 +309,9 @@ class E2eKeysHandler(object):
                 domain = get_domain_from_id(user_id)
                 remote_queries.setdefault(domain, {})[user_id] = device_keys
 
+        opentracing.set_tag("local_key_query", local_query)
+        opentracing.set_tag("remote_key_query", remote_queries)
+
         results = yield self.store.claim_e2e_one_time_keys(local_query)
 
         json_result = {}
@@ -224,8 +323,10 @@ class E2eKeysHandler(object):
                         key_id: json.loads(json_bytes)
                     }
 
+        @opentracing.trace_deferred
         @defer.inlineCallbacks
         def claim_client_keys(destination):
+            opentracing.set_tag("destination", destination)
             device_keys = remote_queries[destination]
             try:
                 remote_result = yield self.federation.claim_client_keys(
@@ -234,8 +335,12 @@ class E2eKeysHandler(object):
                 for user_id, keys in remote_result["one_time_keys"].items():
                     if user_id in device_keys:
                         json_result[user_id] = keys
             except Exception as e:
-                failures[destination] = _exception_to_failure(e)
+                failure = _exception_to_failure(e)
+                failures[destination] = failure
+                opentracing.set_tag("error", True)
+                opentracing.set_tag("reason", failure)
 
         yield make_deferred_yieldable(
             defer.gatherResults(
@ -259,14 +364,21 @@ class E2eKeysHandler(object):
), ),
) )
opentracing.log_kv({"one_time_keys": json_result, "failures": failures})
defer.returnValue({"one_time_keys": json_result, "failures": failures}) defer.returnValue({"one_time_keys": json_result, "failures": failures})
@opentracing.trace_deferred
@opentracing.tag_args
@defer.inlineCallbacks @defer.inlineCallbacks
@opentracing.tag_args
def upload_keys_for_user(self, user_id, device_id, keys): def upload_keys_for_user(self, user_id, device_id, keys):
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys. # TODO: Validate the JSON to make sure it has the right keys.
device_keys = keys.get("device_keys", None) device_keys = keys.get("device_keys", None)
opentracing.set_tag("device_keys", device_keys)
if device_keys: if device_keys:
logger.info( logger.info(
"Updating device_keys for device %r for user %s at %d", "Updating device_keys for device %r for user %s at %d",
@@ -274,6 +386,13 @@ class E2eKeysHandler(object):
                 user_id,
                 time_now,
             )
+            opentracing.log_kv(
+                {
+                    "message": "Updating device_keys for user.",
+                    "user_id": user_id,
+                    "device_id": device_id,
+                }
+            )
             # TODO: Sign the JSON with the server key
             changed = yield self.store.set_e2e_device_keys(
                 user_id, device_id, time_now, device_keys
@@ -281,12 +400,27 @@ class E2eKeysHandler(object):
             if changed:
                 # Only notify about device updates *if* the keys actually changed
                 yield self.device_handler.notify_device_update(user_id, [device_id])
+            else:
+                opentracing.log_kv(
+                    {"message": "Not updating device_keys for user", "user_id": user_id}
+                )
 
         one_time_keys = keys.get("one_time_keys", None)
+        opentracing.set_tag("one_time_keys", one_time_keys)
         if one_time_keys:
+            opentracing.log_kv(
+                {
+                    "message": "Updating one_time_keys for device.",
+                    "user_id": user_id,
+                    "device_id": device_id,
+                }
+            )
             yield self._upload_one_time_keys_for_user(
                 user_id, device_id, time_now, one_time_keys
             )
+        else:
+            opentracing.log_kv(
+                {"message": "Did not update one_time_keys", "reason": "no keys given"}
+            )
 
         # the device should have been registered already, but it may have been
         # deleted due to a race with a DELETE request. Or we may be using an
@@ -297,6 +431,7 @@ class E2eKeysHandler(object):
         result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
 
+        opentracing.set_tag("one_time_key_counts", result)
         defer.returnValue({"one_time_key_counts": result})
 
     @defer.inlineCallbacks
@@ -340,6 +475,9 @@ class E2eKeysHandler(object):
                 (algorithm, key_id, encode_canonical_json(key).decode("ascii"))
             )
 
+        opentracing.log_kv(
+            {"message": "Inserting new one_time_keys.", "keys": new_keys}
+        )
         yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
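
For reference, this is roughly the shape of a client keys/query body that exercises the new caching path: an empty device list means "all of this user's devices", which is the case do_remote_query treats as resyncable. The user ID below is made up for illustration:

    # Hypothetical request body for POST /_matrix/client/r0/keys/query.
    query_body = {
        "device_keys": {
            # Empty list = "give me all of this user's devices"; combined
            # with a shared room, this triggers the resync-and-cache path.
            "@alice:remote.example": []
        }
    }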