From 700f1ca681213e5a47223dfae3243e75cec23a31 Mon Sep 17 00:00:00 2001 From: Jorik Schellekens Date: Tue, 16 Jul 2019 17:01:18 +0100 Subject: [PATCH] Update the device list cache when keys/query is called We only cache it when it makes sense to do so. --- synapse/handlers/device.py | 160 +++++++++++++++++++---------------- synapse/handlers/e2e_keys.py | 148 ++++++++++++++++++++++++++++++-- 2 files changed, 231 insertions(+), 77 deletions(-) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 99e8413092..aae2fa09c8 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -18,6 +18,7 @@ from six import iteritems, itervalues from twisted.internet import defer +import synapse.logging.opentracing as opentracing from synapse.api import errors from synapse.api.constants import EventTypes from synapse.api.errors import ( @@ -211,12 +212,12 @@ class DeviceHandler(DeviceWorkerHandler): self.federation_sender = hs.get_federation_sender() - self._edu_updater = DeviceListEduUpdater(hs, self) + self.device_list_updater = DeviceListUpdater(hs, self) federation_registry = hs.get_federation_registry() federation_registry.register_edu_handler( - "m.device_list_update", self._edu_updater.incoming_device_list_update + "m.device_list_update", self.device_list_updater.incoming_device_list_update ) federation_registry.register_query_handler( "user_devices", self.on_federation_query_user_devices @@ -430,7 +431,7 @@ def _update_device_from_client_ips(device, client_ips): device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) -class DeviceListEduUpdater(object): +class DeviceListUpdater(object): "Handles incoming device list updates from federation and updates the DB" def __init__(self, hs, device_handler): @@ -523,75 +524,7 @@ class DeviceListEduUpdater(object): logger.debug("Need to re-sync devices for %r? %r", user_id, resync) if resync: - # Fetch all devices for the user. - origin = get_domain_from_id(user_id) - try: - result = yield self.federation.query_user_devices(origin, user_id) - except ( - NotRetryingDestination, - RequestSendFailed, - HttpResponseException, - ): - # TODO: Remember that we are now out of sync and try again - # later - logger.warn("Failed to handle device list update for %s", user_id) - # We abort on exceptions rather than accepting the update - # as otherwise synapse will 'forget' that its device list - # is out of date. If we bail then we will retry the resync - # next time we get a device list update for this user_id. - # This makes it more likely that the device lists will - # eventually become consistent. - return - except FederationDeniedError as e: - logger.info(e) - return - except Exception: - # TODO: Remember that we are now out of sync and try again - # later - logger.exception( - "Failed to handle device list update for %s", user_id - ) - return - - stream_id = result["stream_id"] - devices = result["devices"] - - # If the remote server has more than ~1000 devices for this user - # we assume that something is going horribly wrong (e.g. a bot - # that logs in and creates a new device every time it tries to - # send a message). Maintaining lots of devices per user in the - # cache can cause serious performance issues as if this request - # takes more than 60s to complete, internal replication from the - # inbound federation worker to the synapse master may time out - # causing the inbound federation to fail and causing the remote - # server to retry, causing a DoS. So in this scenario we give - # up on storing the total list of devices and only handle the - # delta instead. - if len(devices) > 1000: - logger.warn( - "Ignoring device list snapshot for %s as it has >1K devs (%d)", - user_id, - len(devices), - ) - devices = [] - - for device in devices: - logger.debug( - "Handling resync update %r/%r, ID: %r", - user_id, - device["device_id"], - stream_id, - ) - - yield self.store.update_remote_device_list_cache( - user_id, devices, stream_id - ) - device_ids = [device["device_id"] for device in devices] - yield self.device_handler.notify_device_update(user_id, device_ids) - - # We clobber the seen updates since we've re-synced from a given - # point. - self._seen_updates[user_id] = set([stream_id]) + yield self.user_device_resync(user_id) else: # Simply update the single device, since we know that is the only # change (because of the single prev_id matching the current cache) @@ -638,3 +571,86 @@ class DeviceListEduUpdater(object): stream_id_in_updates.add(stream_id) defer.returnValue(False) + + @opentracing.trace_deferred + @defer.inlineCallbacks + def user_device_resync(self, user_id): + """Fetches all devices for a user and updates the device cache with them. + + Args: + user_id (String): The user's id whose device_list will be updated. + Returns: + a dict with device info as under the "devices" in the result of this + request: + https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + """ + opentracing.log_kv({"message": "Doing resync to update device list."}) + # Fetch all devices for the user. + origin = get_domain_from_id(user_id) + try: + result = yield self.federation.query_user_devices(origin, user_id) + except (NotRetryingDestination, RequestSendFailed, HttpResponseException): + # TODO: Remember that we are now out of sync and try again + # later + logger.warn("Failed to handle device list update for %s", user_id) + # We abort on exceptions rather than accepting the update + # as otherwise synapse will 'forget' that its device list + # is out of date. If we bail then we will retry the resync + # next time we get a device list update for this user_id. + # This makes it more likely that the device lists will + # eventually become consistent. + return + except FederationDeniedError as e: + opentracing.set_tag("error", True) + opentracing.log_kv({"reason": "FederationDeniedError"}) + logger.info(e) + return + except Exception as e: + # TODO: Remember that we are now out of sync and try again + # later + opentracing.set_tag("error", True) + opentracing.log_kv( + {"message": "Exception raised by federation request", "exception": e} + ) + logger.exception("Failed to handle device list update for %s", user_id) + return + opentracing.log_kv({"result": result}) + stream_id = result["stream_id"] + devices = result["devices"] + + # If the remote server has more than ~1000 devices for this user + # we assume that something is going horribly wrong (e.g. a bot + # that logs in and creates a new device every time it tries to + # send a message). Maintaining lots of devices per user in the + # cache can cause serious performance issues as if this request + # takes more than 60s to complete, internal replication from the + # inbound federation worker to the synapse master may time out + # causing the inbound federation to fail and causing the remote + # server to retry, causing a DoS. So in this scenario we give + # up on storing the total list of devices and only handle the + # delta instead. + if len(devices) > 1000: + logger.warn( + "Ignoring device list snapshot for %s as it has >1K devs (%d)", + user_id, + len(devices), + ) + devices = [] + + for device in devices: + logger.debug( + "Handling resync update %r/%r, ID: %r", + user_id, + device["device_id"], + stream_id, + ) + + yield self.store.update_remote_device_list_cache(user_id, devices, stream_id) + device_ids = [device["device_id"] for device in devices] + yield self.device_handler.notify_device_update(user_id, device_ids) + + # We clobber the seen updates since we've re-synced from a given + # point. + self._seen_updates[user_id] = set([stream_id]) + + defer.returnValue(result) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index fdfe8611b6..ca53b20321 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -22,8 +22,14 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer +import synapse.logging.opentracing as opentracing from synapse.api.errors import CodeMessageException, SynapseError -from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.logging.context import ( + LoggingContext, + PreserveLoggingContext, + make_deferred_yieldable, + run_in_background, +) from synapse.types import UserID, get_domain_from_id from synapse.util.retryutils import NotRetryingDestination @@ -45,6 +51,7 @@ class E2eKeysHandler(object): "client_keys", self.on_federation_query_client_keys ) + @opentracing.trace_deferred @defer.inlineCallbacks def query_devices(self, query_body, timeout): """ Handle a device key query from a client @@ -65,6 +72,7 @@ class E2eKeysHandler(object): } } """ + device_keys_query = query_body.get("device_keys", {}) # separate users by domain. @@ -79,6 +87,9 @@ class E2eKeysHandler(object): else: remote_queries[user_id] = device_ids + opentracing.set_tag("local_key_query", local_query) + opentracing.set_tag("remote_key_query", remote_queries) + # First get local devices. failures = {} results = {} @@ -119,9 +130,81 @@ class E2eKeysHandler(object): r[user_id] = remote_queries[user_id] # Now fetch any devices that we don't have in our cache + @opentracing.trace_deferred @defer.inlineCallbacks def do_remote_query(destination): destination_query = remote_queries_not_in_cache[destination] + + opentracing.set_tag("key_query", destination_query) + + # We first consider whether we wish to update the dive list cache with + # the users device list. We want to track a user's devices when the + # authenticated user shares a room with the queried user and the query + # has not specified a particular device. + # If we update the cache for the queried user we remove them from further + # queries. We use the more efficient batched query_client_keys for all + # remaining users + user_ids_updated = [] + for (user_id, device_list) in destination_query.items(): + if user_id not in user_ids_updated: + try: + with PreserveLoggingContext(LoggingContext.current_context()): + room_ids = yield self.store.get_rooms_for_user(user_id) + if not device_list and room_ids: + opentracing.log_kv( + { + "message": "Resyncing devices for user", + "user_id": user_id, + } + ) + user_devices = yield self.device_handler.device_list_updater.user_device_resync( + user_id + ) + user_devices = user_devices["devices"] + opentracing.log_kv( + { + "message": "got user devices", + "user_devices": user_devices, + } + ) + for device in user_devices: + results[user_id] = { + device["device_id"]: device["keys"] + } + opentracing.log_kv( + {"adding user to user_ids_updated": user_id} + ) + user_ids_updated.append(user_id) + else: + opentracing.log_kv( + { + "message": "Not resyncing devices for user", + "user_id": user_id, + } + ) + except Exception as e: + failures[destination] = failures.get(destination, []).append( + _exception_to_failure(e) + ) + opentracing.set_tag("error", True) + opentracing.log_kv({"exception": e}) + + if len(destination_query) == len(user_ids_updated): + # We've updated all the users in the query and we do not need to + # make any further remote calls. + return + + # Remove all the users from the query which we have updated + for user_id in user_ids_updated: + destination_query.pop(user_id) + + opentracing.log_kv( + { + "message": "Querying remote servers for keys", + "destination_query": destination_query, + "not querying": user_ids_updated, + } + ) try: remote_result = yield self.federation.query_client_keys( destination, {"device_keys": destination_query}, timeout=timeout @@ -132,7 +215,8 @@ class E2eKeysHandler(object): results[user_id] = keys except Exception as e: - failures[destination] = _exception_to_failure(e) + failure = _exception_to_failure(e) + failures[destination] = failure yield make_deferred_yieldable( defer.gatherResults( @@ -143,9 +227,10 @@ class E2eKeysHandler(object): consumeErrors=True, ) ) - + opentracing.log_kv({"device_keys": results, "failures": failures}) defer.returnValue({"device_keys": results, "failures": failures}) + @opentracing.trace_deferred @defer.inlineCallbacks def query_local_devices(self, query): """Get E2E device keys for local users @@ -158,6 +243,7 @@ class E2eKeysHandler(object): defer.Deferred: (resolves to dict[string, dict[string, dict]]): map from user_id -> device_id -> device details """ + opentracing.set_tag("local_query", query) local_query = [] result_dict = {} @@ -165,6 +251,14 @@ class E2eKeysHandler(object): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): logger.warning("Request for keys for non-local user %s", user_id) + opentracing.log_kv( + { + "message": "Requested a local key for a user which" + + " was not local to the homeserver", + "user_id": user_id, + } + ) + opentracing.set_tag("error", True) raise SynapseError(400, "Not a user here") if not device_ids: @@ -189,6 +283,7 @@ class E2eKeysHandler(object): r["unsigned"]["device_display_name"] = display_name result_dict[user_id][device_id] = r + opentracing.log_kv(results) defer.returnValue(result_dict) @defer.inlineCallbacks @@ -199,6 +294,7 @@ class E2eKeysHandler(object): res = yield self.query_local_devices(device_keys_query) defer.returnValue({"device_keys": res}) + @opentracing.trace_deferred @defer.inlineCallbacks def claim_one_time_keys(self, query, timeout): local_query = [] @@ -213,6 +309,9 @@ class E2eKeysHandler(object): domain = get_domain_from_id(user_id) remote_queries.setdefault(domain, {})[user_id] = device_keys + opentracing.set_tag("local_key_query", local_query) + opentracing.set_tag("remote_key_query", remote_queries) + results = yield self.store.claim_e2e_one_time_keys(local_query) json_result = {} @@ -224,8 +323,10 @@ class E2eKeysHandler(object): key_id: json.loads(json_bytes) } + @opentracing.trace_deferred @defer.inlineCallbacks def claim_client_keys(destination): + opentracing.set_tag("destination", destination) device_keys = remote_queries[destination] try: remote_result = yield self.federation.claim_client_keys( @@ -234,8 +335,12 @@ class E2eKeysHandler(object): for user_id, keys in remote_result["one_time_keys"].items(): if user_id in device_keys: json_result[user_id] = keys + except Exception as e: - failures[destination] = _exception_to_failure(e) + failure = _exception_to_failure(e) + failures[destination] = failure + opentracing.set_tag("error", True) + opentracing.set_tag("reason", failure) yield make_deferred_yieldable( defer.gatherResults( @@ -259,14 +364,21 @@ class E2eKeysHandler(object): ), ) + opentracing.log_kv({"one_time_keys": json_result, "failures": failures}) + defer.returnValue({"one_time_keys": json_result, "failures": failures}) + @opentracing.trace_deferred + @opentracing.tag_args @defer.inlineCallbacks + @opentracing.tag_args def upload_keys_for_user(self, user_id, device_id, keys): + time_now = self.clock.time_msec() # TODO: Validate the JSON to make sure it has the right keys. device_keys = keys.get("device_keys", None) + opentracing.set_tag("device_keys", device_keys) if device_keys: logger.info( "Updating device_keys for device %r for user %s at %d", @@ -274,6 +386,13 @@ class E2eKeysHandler(object): user_id, time_now, ) + opentracing.log_kv( + { + "message": "Updating device_keys for user.", + "user_id": user_id, + "device_id": device_id, + } + ) # TODO: Sign the JSON with the server key changed = yield self.store.set_e2e_device_keys( user_id, device_id, time_now, device_keys @@ -281,12 +400,27 @@ class E2eKeysHandler(object): if changed: # Only notify about device updates *if* the keys actually changed yield self.device_handler.notify_device_update(user_id, [device_id]) - + else: + opentracing.log_kv( + {"message": "Not updating device_keys for user", "user_id": user_id} + ) one_time_keys = keys.get("one_time_keys", None) + opentracing.set_tag("one_time_keys", one_time_keys) if one_time_keys: + opentracing.log_kv( + { + "message": "Updating one_time_keys for device.", + "user_id": user_id, + "device_id": device_id, + } + ) yield self._upload_one_time_keys_for_user( user_id, device_id, time_now, one_time_keys ) + else: + opentracing.log_kv( + {"message": "Did not update one_time_keys", "reason": "no keys given"} + ) # the device should have been registered already, but it may have been # deleted due to a race with a DELETE request. Or we may be using an @@ -297,6 +431,7 @@ class E2eKeysHandler(object): result = yield self.store.count_e2e_one_time_keys(user_id, device_id) + opentracing.set_tag("one_time_key_counts", result) defer.returnValue({"one_time_key_counts": result}) @defer.inlineCallbacks @@ -340,6 +475,9 @@ class E2eKeysHandler(object): (algorithm, key_id, encode_canonical_json(key).decode("ascii")) ) + opentracing.log_kv( + {"message": "Inserting new one_time_keys.", "keys": new_keys} + ) yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)