From 5ecc8b58255d7e33ad63a6c931efa6ed5e41ad01 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Aug 2020 10:51:42 -0400 Subject: [PATCH] Convert devices database to async/await. (#8069) --- changelog.d/8069.misc | 1 + synapse/storage/databases/main/devices.py | 333 ++++++++++++---------- tests/handlers/test_typing.py | 2 +- tests/storage/test_devices.py | 44 ++- tests/storage/test_end_to_end_keys.py | 16 +- 5 files changed, 220 insertions(+), 176 deletions(-) create mode 100644 changelog.d/8069.misc diff --git a/changelog.d/8069.misc b/changelog.d/8069.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8069.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 7a5f0bab05..2b33060480 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -15,9 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, Optional, Set, Tuple - -from twisted.internet import defer +from typing import Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( @@ -33,14 +31,9 @@ from synapse.storage.database import ( LoggingTransaction, make_tuple_comparison_clause, ) -from synapse.types import Collection, get_verify_key_from_cross_signing_key +from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_encoder -from synapse.util.caches.descriptors import ( - Cache, - cached, - cachedInlineCallbacks, - cachedList, -) +from synapse.util.caches.descriptors import Cache, cached, cachedList from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -54,13 +47,13 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" class DeviceWorkerStore(SQLBaseStore): - def get_device(self, user_id, device_id): + def get_device(self, user_id: str, device_id: str): """Retrieve a device. Only returns devices that are not marked as hidden. Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to retrieve + user_id: The ID of the user which owns the device + device_id: The ID of the device to retrieve Returns: defer.Deferred for a dict containing the device information Raises: @@ -73,19 +66,17 @@ class DeviceWorkerStore(SQLBaseStore): desc="get_device", ) - @defer.inlineCallbacks - def get_devices_by_user(self, user_id): + async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]: """Retrieve all of a user's registered devices. Only returns devices that are not marked as hidden. Args: - user_id (str): + user_id: Returns: - defer.Deferred: resolves to a dict from device_id to a dict - containing "device_id", "user_id" and "display_name" for each - device. + A mapping from device_id to a dict containing "device_id", "user_id" + and "display_name" for each device. """ - devices = yield self.db_pool.simple_select_list( + devices = await self.db_pool.simple_select_list( table="devices", keyvalues={"user_id": user_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -95,19 +86,20 @@ class DeviceWorkerStore(SQLBaseStore): return {d["device_id"]: d for d in devices} @trace - @defer.inlineCallbacks - def get_device_updates_by_remote(self, destination, from_stream_id, limit): + async def get_device_updates_by_remote( + self, destination: str, from_stream_id: int, limit: int + ) -> Tuple[int, List[Tuple[str, dict]]]: """Get a stream of device updates to send to the given remote server. Args: - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive - limit (int): Maximum number of device updates to return + destination: The host the device updates are intended for + from_stream_id: The minimum stream_id to filter updates by, exclusive + limit: Maximum number of device updates to return + Returns: - Deferred[tuple[int, list[tuple[string,dict]]]]: - current stream id (ie, the stream id of the last update included in the - response), and the list of updates, where each update is a pair of EDU - type and EDU contents + A mapping from the current stream id (ie, the stream id of the last + update included in the response), and the list of updates, where + each update is a pair of EDU type and EDU contents. """ now_stream_id = self._device_list_id_gen.get_current_token() @@ -117,7 +109,7 @@ class DeviceWorkerStore(SQLBaseStore): if not has_changed: return now_stream_id, [] - updates = yield self.db_pool.runInteraction( + updates = await self.db_pool.runInteraction( "get_device_updates_by_remote", self._get_device_updates_by_remote_txn, destination, @@ -136,9 +128,7 @@ class DeviceWorkerStore(SQLBaseStore): master_key_by_user = {} self_signing_key_by_user = {} for user in users: - cross_signing_key = yield defer.ensureDeferred( - self.get_e2e_cross_signing_key(user, "master") - ) + cross_signing_key = await self.get_e2e_cross_signing_key(user, "master") if cross_signing_key: key_id, verify_key = get_verify_key_from_cross_signing_key( cross_signing_key @@ -151,8 +141,8 @@ class DeviceWorkerStore(SQLBaseStore): "device_id": verify_key.version, } - cross_signing_key = yield defer.ensureDeferred( - self.get_e2e_cross_signing_key(user, "self_signing") + cross_signing_key = await self.get_e2e_cross_signing_key( + user, "self_signing" ) if cross_signing_key: key_id, verify_key = get_verify_key_from_cross_signing_key( @@ -202,7 +192,7 @@ class DeviceWorkerStore(SQLBaseStore): if update_stream_id > previous_update_stream_id: query_map[key] = (update_stream_id, update_context) - results = yield self._get_device_update_edus_by_remote( + results = await self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) @@ -215,16 +205,21 @@ class DeviceWorkerStore(SQLBaseStore): return now_stream_id, results def _get_device_updates_by_remote_txn( - self, txn, destination, from_stream_id, now_stream_id, limit + self, + txn: LoggingTransaction, + destination: str, + from_stream_id: int, + now_stream_id: int, + limit: int, ): """Return device update information for a given remote destination Args: - txn (LoggingTransaction): The transaction to execute - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive - now_stream_id (int): The maximum stream_id to filter updates by, inclusive - limit (int): Maximum number of device updates to return + txn: The transaction to execute + destination: The host the device updates are intended for + from_stream_id: The minimum stream_id to filter updates by, exclusive + now_stream_id: The maximum stream_id to filter updates by, inclusive + limit: Maximum number of device updates to return Returns: List: List of device updates @@ -240,23 +235,26 @@ class DeviceWorkerStore(SQLBaseStore): return list(txn) - @defer.inlineCallbacks - def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map): + async def _get_device_update_edus_by_remote( + self, + destination: str, + from_stream_id: int, + query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]], + ) -> List[Tuple[str, dict]]: """Returns a list of device update EDUs as well as E2EE keys Args: - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive + destination: The host the device updates are intended for + from_stream_id: The minimum stream_id to filter updates by, exclusive query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping user_id/device_id to update stream_id and the relevant json-encoded opentracing context Returns: - List[Dict]: List of objects representing an device update EDU - + List of objects representing an device update EDU """ devices = ( - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_get_e2e_device_keys_txn", self._get_e2e_device_keys_txn, query_map.keys(), @@ -271,7 +269,7 @@ class DeviceWorkerStore(SQLBaseStore): for user_id, user_devices in devices.items(): # The prev_id for the first row is always the last row before # `from_stream_id` - prev_id = yield self._get_last_device_update_for_remote_user( + prev_id = await self._get_last_device_update_for_remote_user( destination, user_id, from_stream_id ) @@ -315,7 +313,7 @@ class DeviceWorkerStore(SQLBaseStore): return results def _get_last_device_update_for_remote_user( - self, destination, user_id, from_stream_id + self, destination: str, user_id: str, from_stream_id: int ): def f(txn): prev_sent_id_sql = """ @@ -329,7 +327,7 @@ class DeviceWorkerStore(SQLBaseStore): return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f) - def mark_as_sent_devices_by_remote(self, destination, stream_id): + def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int): """Mark that updates have successfully been sent to the destination. """ return self.db_pool.runInteraction( @@ -339,7 +337,9 @@ class DeviceWorkerStore(SQLBaseStore): stream_id, ) - def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): + def _mark_as_sent_devices_by_remote_txn( + self, txn: LoggingTransaction, destination: str, stream_id: int + ) -> None: # We update the device_lists_outbound_last_success with the successfully # poked users. sql = """ @@ -367,17 +367,21 @@ class DeviceWorkerStore(SQLBaseStore): """ txn.execute(sql, (destination, stream_id)) - @defer.inlineCallbacks - def add_user_signature_change_to_streams(self, from_user_id, user_ids): + async def add_user_signature_change_to_streams( + self, from_user_id: str, user_ids: List[str] + ) -> int: """Persist that a user has made new signatures Args: - from_user_id (str): the user who made the signatures - user_ids (list[str]): the users who were signed + from_user_id: the user who made the signatures + user_ids: the users who were signed + + Returns: + THe new stream ID. """ with self._device_list_id_gen.get_next() as stream_id: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, from_user_id, @@ -386,7 +390,13 @@ class DeviceWorkerStore(SQLBaseStore): ) return stream_id - def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id): + def _add_user_signature_change_txn( + self, + txn: LoggingTransaction, + from_user_id: str, + user_ids: List[str], + stream_id: int, + ) -> None: txn.call_after( self._user_signature_stream_cache.entity_has_changed, from_user_id, @@ -402,29 +412,30 @@ class DeviceWorkerStore(SQLBaseStore): }, ) - def get_device_stream_token(self): + def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() @trace - @defer.inlineCallbacks - def get_user_devices_from_cache(self, query_list): + async def get_user_devices_from_cache( + self, query_list: List[Tuple[str, str]] + ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: """Get the devices (and keys if any) for remote users from the cache. Args: - query_list(list): List of (user_id, device_ids), if device_ids is + query_list: List of (user_id, device_ids), if device_ids is falsey then return all device ids for that user. Returns: - (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is - a set of user_ids and results_map is a mapping of - user_id -> device_id -> device_info + A tuple of (user_ids_not_in_cache, results_map), where + user_ids_not_in_cache is a set of user_ids and results_map is a + mapping of user_id -> device_id -> device_info. """ user_ids = {user_id for user_id, _ in query_list} - user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) + user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids)) # We go and check if any of the users need to have their device lists # resynced. If they do then we remove them from the cached list. - users_needing_resync = yield self.get_user_ids_requiring_device_list_resync( + users_needing_resync = await self.get_user_ids_requiring_device_list_resync( user_ids ) user_ids_in_cache = { @@ -438,19 +449,19 @@ class DeviceWorkerStore(SQLBaseStore): continue if device_id: - device = yield self._get_cached_user_device(user_id, device_id) + device = await self._get_cached_user_device(user_id, device_id) results.setdefault(user_id, {})[device_id] = device else: - results[user_id] = yield self.get_cached_devices_for_user(user_id) + results[user_id] = await self.get_cached_devices_for_user(user_id) set_tag("in_cache", results) set_tag("not_in_cache", user_ids_not_in_cache) return user_ids_not_in_cache, results - @cachedInlineCallbacks(num_args=2, tree=True) - def _get_cached_user_device(self, user_id, device_id): - content = yield self.db_pool.simple_select_one_onecol( + @cached(num_args=2, tree=True) + async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict: + content = await self.db_pool.simple_select_one_onecol( table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="content", @@ -458,9 +469,9 @@ class DeviceWorkerStore(SQLBaseStore): ) return db_to_json(content) - @cachedInlineCallbacks() - def get_cached_devices_for_user(self, user_id): - devices = yield self.db_pool.simple_select_list( + @cached() + async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]: + devices = await self.db_pool.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, retcols=("device_id", "content"), @@ -470,11 +481,11 @@ class DeviceWorkerStore(SQLBaseStore): device["device_id"]: db_to_json(device["content"]) for device in devices } - def get_devices_with_keys_by_user(self, user_id): + def get_devices_with_keys_by_user(self, user_id: str): """Get all devices (with any device keys) for a user Returns: - (stream_id, devices) + Deferred which resolves to (stream_id, devices) """ return self.db_pool.runInteraction( "get_devices_with_keys_by_user", @@ -482,7 +493,9 @@ class DeviceWorkerStore(SQLBaseStore): user_id, ) - def _get_devices_with_keys_by_user_txn(self, txn, user_id): + def _get_devices_with_keys_by_user_txn( + self, txn: LoggingTransaction, user_id: str + ) -> Tuple[int, List[JsonDict]]: now_stream_id = self._device_list_id_gen.get_current_token() devices = self._get_e2e_device_keys_txn( @@ -515,17 +528,18 @@ class DeviceWorkerStore(SQLBaseStore): return now_stream_id, [] - def get_users_whose_devices_changed(self, from_key, user_ids): + async def get_users_whose_devices_changed( + self, from_key: str, user_ids: Iterable[str] + ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that are in the given list of user_ids. Args: - from_key (str): The device lists stream token - user_ids (Iterable[str]) + from_key: The device lists stream token + user_ids: The user IDs to query for devices. Returns: - Deferred[set[str]]: The set of user_ids whose devices have changed - since `from_key` + The set of user_ids whose devices have changed since `from_key` """ from_key = int(from_key) @@ -536,7 +550,7 @@ class DeviceWorkerStore(SQLBaseStore): ) if not to_check: - return defer.succeed(set()) + return set() def _get_users_whose_devices_changed_txn(txn): changes = set() @@ -556,18 +570,22 @@ class DeviceWorkerStore(SQLBaseStore): return changes - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn ) - @defer.inlineCallbacks - def get_users_whose_signatures_changed(self, user_id, from_key): + async def get_users_whose_signatures_changed( + self, user_id: str, from_key: str + ) -> Set[str]: """Get the users who have new cross-signing signatures made by `user_id` since `from_key`. Args: - user_id (str): the user who made the signatures - from_key (str): The device lists stream token + user_id: the user who made the signatures + from_key: The device lists stream token + + Returns: + A set of user IDs with updated signatures. """ from_key = int(from_key) if self._user_signature_stream_cache.has_entity_changed(user_id, from_key): @@ -575,7 +593,7 @@ class DeviceWorkerStore(SQLBaseStore): SELECT DISTINCT user_ids FROM user_signature_stream WHERE from_user_id = ? AND stream_id > ? """ - rows = yield self.db_pool.execute( + rows = await self.db_pool.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) return {user for row in rows for user in db_to_json(row[0])} @@ -638,7 +656,7 @@ class DeviceWorkerStore(SQLBaseStore): ) @cached(max_entries=10000) - def get_device_list_last_stream_id_for_remote(self, user_id): + def get_device_list_last_stream_id_for_remote(self, user_id: str): """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ @@ -655,7 +673,7 @@ class DeviceWorkerStore(SQLBaseStore): list_name="user_ids", inlineCallbacks=True, ) - def get_device_list_last_stream_id_for_remotes(self, user_ids): + def get_device_list_last_stream_id_for_remotes(self, user_ids: str): rows = yield self.db_pool.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", @@ -669,8 +687,7 @@ class DeviceWorkerStore(SQLBaseStore): return results - @defer.inlineCallbacks - def get_user_ids_requiring_device_list_resync( + async def get_user_ids_requiring_device_list_resync( self, user_ids: Optional[Collection[str]] = None, ) -> Set[str]: """Given a list of remote users return the list of users that we @@ -681,7 +698,7 @@ class DeviceWorkerStore(SQLBaseStore): The IDs of users whose device lists need resync. """ if user_ids: - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="device_lists_remote_resync", column="user_id", iterable=user_ids, @@ -689,7 +706,7 @@ class DeviceWorkerStore(SQLBaseStore): desc="get_user_ids_requiring_device_list_resync_with_iterable", ) else: - rows = yield self.db_pool.simple_select_list( + rows = await self.db_pool.simple_select_list( table="device_lists_remote_resync", keyvalues=None, retcols=("user_id",), @@ -710,7 +727,7 @@ class DeviceWorkerStore(SQLBaseStore): desc="make_remote_user_device_cache_as_stale", ) - def mark_remote_user_device_list_as_unsubscribed(self, user_id): + def mark_remote_user_device_list_as_unsubscribed(self, user_id: str): """Mark that we no longer track device lists for remote user. """ @@ -779,16 +796,15 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): "drop_device_lists_outbound_last_success_non_unique_idx", ) - @defer.inlineCallbacks - def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): + async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): def f(conn): txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") txn.close() - yield self.db_pool.runWithConnection(f) - yield self.db_pool.updates._end_background_update( + await self.db_pool.runWithConnection(f) + await self.db_pool.updates._end_background_update( DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES ) return 1 @@ -868,18 +884,20 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) - @defer.inlineCallbacks - def store_device(self, user_id, device_id, initial_device_display_name): + async def store_device( + self, user_id: str, device_id: str, initial_device_display_name: str + ) -> bool: """Ensure the given device is known; add it to the store if not Args: - user_id (str): id of user associated with the device - device_id (str): id of device - initial_device_display_name (str): initial displayname of the - device. Ignored if device exists. + user_id: id of user associated with the device + device_id: id of device + initial_device_display_name: initial displayname of the device. + Ignored if device exists. + Returns: - defer.Deferred: boolean whether the device was inserted or an - existing device existed with that ID. + Whether the device was inserted or an existing device existed with that ID. + Raises: StoreError: if the device is already in use """ @@ -888,7 +906,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return False try: - inserted = yield self.db_pool.simple_insert( + inserted = await self.db_pool.simple_insert( "devices", values={ "user_id": user_id, @@ -902,7 +920,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not inserted: # if the device already exists, check if it's a real device, or # if the device ID is reserved by something else - hidden = yield self.db_pool.simple_select_one_onecol( + hidden = await self.db_pool.simple_select_one_onecol( "devices", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="hidden", @@ -927,17 +945,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) raise StoreError(500, "Problem storing device.") - @defer.inlineCallbacks - def delete_device(self, user_id, device_id): + async def delete_device(self, user_id: str, device_id: str) -> None: """Delete a device. Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to delete - Returns: - defer.Deferred + user_id: The ID of the user which owns the device + device_id: The ID of the device to delete """ - yield self.db_pool.simple_delete_one( + await self.db_pool.simple_delete_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, desc="delete_device", @@ -945,17 +960,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.device_id_exists_cache.invalidate((user_id, device_id)) - @defer.inlineCallbacks - def delete_devices(self, user_id, device_ids): + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Deletes several devices. Args: - user_id (str): The ID of the user which owns the devices - device_ids (list): The IDs of the devices to delete - Returns: - defer.Deferred + user_id: The ID of the user which owns the devices + device_ids: The IDs of the devices to delete """ - yield self.db_pool.simple_delete_many( + await self.db_pool.simple_delete_many( table="devices", column="device_id", iterable=device_ids, @@ -965,26 +977,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) - def update_device(self, user_id, device_id, new_display_name=None): + async def update_device( + self, user_id: str, device_id: str, new_display_name: Optional[str] = None + ) -> None: """Update a device. Only updates the device if it is not marked as hidden. Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to update - new_display_name (str|None): new displayname for device; None - to leave unchanged + user_id: The ID of the user which owns the device + device_id: The ID of the device to update + new_display_name: new displayname for device; None to leave unchanged Raises: StoreError: if the device is not found - Returns: - defer.Deferred """ updates = {} if new_display_name is not None: updates["display_name"] = new_display_name if not updates: - return defer.succeed(None) - return self.db_pool.simple_update_one( + return None + await self.db_pool.simple_update_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, updatevalues=updates, @@ -992,7 +1003,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) def update_remote_device_list_cache_entry( - self, user_id, device_id, content, stream_id + self, user_id: str, device_id: str, content: JsonDict, stream_id: int ): """Updates a single device in the cache of a remote user's devicelist. @@ -1000,10 +1011,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device list. Args: - user_id (str): User to update device list for - device_id (str): ID of decivice being updated - content (dict): new data on this device - stream_id (int): the version of the device list + user_id: User to update device list for + device_id: ID of decivice being updated + content: new data on this device + stream_id: the version of the device list Returns: Deferred[None] @@ -1018,8 +1029,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) def _update_remote_device_list_cache_entry_txn( - self, txn, user_id, device_id, content, stream_id - ): + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + content: JsonDict, + stream_id: int, + ) -> None: if content.get("deleted"): self.db_pool.simple_delete_txn( txn, @@ -1055,16 +1071,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): lock=False, ) - def update_remote_device_list_cache(self, user_id, devices, stream_id): + def update_remote_device_list_cache( + self, user_id: str, devices: List[dict], stream_id: int + ): """Replace the entire cache of the remote user's devices. Note: assumes that we are the only thread that can be updating this user's device list. Args: - user_id (str): User to update device list for - devices (list[dict]): list of device objects supplied over federation - stream_id (int): the version of the device list + user_id: User to update device list for + devices: list of device objects supplied over federation + stream_id: the version of the device list Returns: Deferred[None] @@ -1077,7 +1095,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): stream_id, ) - def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): + def _update_remote_device_list_cache_txn( + self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int + ): self.db_pool.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) @@ -1118,8 +1138,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id}, ) - @defer.inlineCallbacks - def add_device_change_to_streams(self, user_id, device_ids, hosts): + async def add_device_change_to_streams( + self, user_id: str, device_ids: Collection[str], hosts: List[str] + ): """Persist that a user's devices have been updated, and which hosts (if any) should be poked. """ @@ -1127,7 +1148,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_device_change_to_stream", self._add_device_change_to_stream_txn, user_id, @@ -1142,7 +1163,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): with self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_device_outbound_poke_to_stream", self._add_device_outbound_poke_to_stream_txn, user_id, @@ -1187,7 +1208,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) def _add_device_outbound_poke_to_stream_txn( - self, txn, user_id, device_ids, hosts, stream_ids, context, + self, + txn: LoggingTransaction, + user_id: str, + device_ids: Collection[str], + hosts: List[str], + stream_ids: List[str], + context: Dict[str, str], ): for host in hosts: txn.call_after( @@ -1219,7 +1246,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ], ) - def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000): + def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000): """Delete old entries out of the device_lists_outbound_pokes to ensure that we don't fill up due to dead servers. diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 64ddd8243d..64afd581bc 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -116,7 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): retry_timings_res ) - self.datastore.get_device_updates_by_remote.return_value = defer.succeed( + self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable( (0, []) ) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index c2539b353a..87ed8f8cd1 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -34,7 +34,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_store_new_device(self): - yield self.store.store_device("user_id", "device_id", "display_name") + yield defer.ensureDeferred( + self.store.store_device("user_id", "device_id", "display_name") + ) res = yield self.store.get_device("user_id", "device_id") self.assertDictContainsSubset( @@ -48,11 +50,17 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_get_devices_by_user(self): - yield self.store.store_device("user_id", "device1", "display_name 1") - yield self.store.store_device("user_id", "device2", "display_name 2") - yield self.store.store_device("user_id2", "device3", "display_name 3") + yield defer.ensureDeferred( + self.store.store_device("user_id", "device1", "display_name 1") + ) + yield defer.ensureDeferred( + self.store.store_device("user_id", "device2", "display_name 2") + ) + yield defer.ensureDeferred( + self.store.store_device("user_id2", "device3", "display_name 3") + ) - res = yield self.store.get_devices_by_user("user_id") + res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id")) self.assertEqual(2, len(res.keys())) self.assertDictContainsSubset( { @@ -76,13 +84,13 @@ class DeviceStoreTestCase(tests.unittest.TestCase): device_ids = ["device_id1", "device_id2"] # Add two device updates with a single stream_id - yield self.store.add_device_change_to_streams( - "user_id", device_ids, ["somehost"] + yield defer.ensureDeferred( + self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) ) # Get all device updates ever meant for this remote - now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( - "somehost", -1, limit=100 + now_stream_id, device_updates = yield defer.ensureDeferred( + self.store.get_device_updates_by_remote("somehost", -1, limit=100) ) # Check original device_ids are contained within these updates @@ -99,19 +107,23 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_update_device(self): - yield self.store.store_device("user_id", "device_id", "display_name 1") + yield defer.ensureDeferred( + self.store.store_device("user_id", "device_id", "display_name 1") + ) res = yield self.store.get_device("user_id", "device_id") self.assertEqual("display_name 1", res["display_name"]) # do a no-op first - yield self.store.update_device("user_id", "device_id") + yield defer.ensureDeferred(self.store.update_device("user_id", "device_id")) res = yield self.store.get_device("user_id", "device_id") self.assertEqual("display_name 1", res["display_name"]) # do the update - yield self.store.update_device( - "user_id", "device_id", new_display_name="display_name 2" + yield defer.ensureDeferred( + self.store.update_device( + "user_id", "device_id", new_display_name="display_name 2" + ) ) # check it worked @@ -121,7 +133,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_update_unknown_device(self): with self.assertRaises(synapse.api.errors.StoreError) as cm: - yield self.store.update_device( - "user_id", "unknown_device_id", new_display_name="display_name 2" + yield defer.ensureDeferred( + self.store.update_device( + "user_id", "unknown_device_id", new_display_name="display_name 2" + ) ) self.assertEqual(404, cm.exception.code) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 9f8d30373b..d57cdffd8b 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -30,7 +30,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): now = 1470174257070 json = {"key": "value"} - yield self.store.store_device("user", "device", None) + yield defer.ensureDeferred(self.store.store_device("user", "device", None)) yield self.store.set_e2e_device_keys("user", "device", now, json) @@ -47,7 +47,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): now = 1470174257070 json = {"key": "value"} - yield self.store.store_device("user", "device", None) + yield defer.ensureDeferred(self.store.store_device("user", "device", None)) changed = yield self.store.set_e2e_device_keys("user", "device", now, json) self.assertTrue(changed) @@ -63,7 +63,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): json = {"key": "value"} yield self.store.set_e2e_device_keys("user", "device", now, json) - yield self.store.store_device("user", "device", "display_name") + yield defer.ensureDeferred( + self.store.store_device("user", "device", "display_name") + ) res = yield defer.ensureDeferred( self.store.get_e2e_device_keys((("user", "device"),)) @@ -79,10 +81,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): def test_multiple_devices(self): now = 1470174257070 - yield self.store.store_device("user1", "device1", None) - yield self.store.store_device("user1", "device2", None) - yield self.store.store_device("user2", "device1", None) - yield self.store.store_device("user2", "device2", None) + yield defer.ensureDeferred(self.store.store_device("user1", "device1", None)) + yield defer.ensureDeferred(self.store.store_device("user1", "device2", None)) + yield defer.ensureDeferred(self.store.store_device("user2", "device1", None)) + yield defer.ensureDeferred(self.store.store_device("user2", "device2", None)) yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})