diff --git a/.github/workflows/fix_lint.yaml b/.github/workflows/fix_lint.yaml index 5970b4e826..909b0a847f 100644 --- a/.github/workflows/fix_lint.yaml +++ b/.github/workflows/fix_lint.yaml @@ -29,10 +29,14 @@ jobs: with: install-project: "false" - - name: Run ruff + - name: Run ruff check continue-on-error: true run: poetry run ruff check --fix . + - name: Run ruff format + continue-on-error: true + run: poetry run ruff format --quiet . + - run: cargo clippy --all-features --fix -- -D warnings continue-on-error: true diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index add046ec6a..5586bd6d94 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -131,9 +131,12 @@ jobs: with: install-project: "false" - - name: Check style + - name: Run ruff check run: poetry run ruff check --output-format=github . + - name: Run ruff format + run: poetry run ruff format --check . + lint-mypy: runs-on: ubuntu-latest name: Typechecking diff --git a/changelog.d/17643.misc b/changelog.d/17643.misc new file mode 100644 index 0000000000..f583cdcb38 --- /dev/null +++ b/changelog.d/17643.misc @@ -0,0 +1 @@ +Replace `isort` and `black with `ruff`. diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index d4ddeb4dc7..ca2e72b5e8 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -21,7 +21,8 @@ # # -""" Starts a synapse client console. """ +"""Starts a synapse client console.""" + import argparse import binascii import cmd diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py index 9e67375b6a..26d667aba0 100755 --- a/scripts-dev/check_pydantic_models.py +++ b/scripts-dev/check_pydantic_models.py @@ -31,6 +31,7 @@ Pydantic does not yet offer a strict mode, but it is planned for pydantic v2. Se until then, this script is a best effort to stop us from introducing type coersion bugs (like the infamous stringy power levels fixed in room version 10). """ + import argparse import contextlib import functools diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index fa6ff90708..c656047729 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -109,6 +109,9 @@ set -x # --quiet suppresses the update check. ruff check --quiet --fix "${files[@]}" +# Reformat Python code. +ruff format --quiet "${files[@]}" + # Catch any common programming mistakes in Rust code. # # --bins, --examples, --lib, --tests combined explicitly disable checking diff --git a/scripts-dev/release.py b/scripts-dev/release.py index 1ace804682..4435624267 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -20,8 +20,7 @@ # # -"""An interactive script for doing a release. See `cli()` below. -""" +"""An interactive script for doing a release. See `cli()` below.""" import glob import json diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index a141218d3d..c9a4114b1e 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Contains *incomplete* type hints for txredisapi. -""" +"""Contains *incomplete* type hints for txredisapi.""" + from typing import Any, List, Optional, Type, Union from twisted.internet import protocol diff --git a/synapse/__init__.py b/synapse/__init__.py index 99ed7a5374..73b92f12be 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -20,8 +20,7 @@ # # -""" This is an implementation of a Matrix homeserver. -""" +"""This is an implementation of a Matrix homeserver.""" import os import sys diff --git a/synapse/_scripts/generate_workers_map.py b/synapse/_scripts/generate_workers_map.py index 715c7ddc17..09feb8cf30 100755 --- a/synapse/_scripts/generate_workers_map.py +++ b/synapse/_scripts/generate_workers_map.py @@ -171,7 +171,7 @@ def elide_http_methods_if_unconflicting( """ def paths_to_methods_dict( - methods_and_paths: Iterable[Tuple[str, str]] + methods_and_paths: Iterable[Tuple[str, str]], ) -> Dict[str, Set[str]]: """ Given (method, path) pairs, produces a dict from path to set of methods @@ -201,7 +201,7 @@ def elide_http_methods_if_unconflicting( def simplify_path_regexes( - registrations: Dict[Tuple[str, str], EndpointDescription] + registrations: Dict[Tuple[str, str], EndpointDescription], ) -> Dict[Tuple[str, str], EndpointDescription]: """ Simplify all the path regexes for the dict of endpoint descriptions, diff --git a/synapse/_scripts/review_recent_signups.py b/synapse/_scripts/review_recent_signups.py index ad88df477a..62723c539d 100644 --- a/synapse/_scripts/review_recent_signups.py +++ b/synapse/_scripts/review_recent_signups.py @@ -40,6 +40,7 @@ from synapse.storage.engines import create_engine class ReviewConfig(RootConfig): "A config class that just pulls out the database config" + config_classes = [DatabaseConfig] @@ -160,7 +161,11 @@ def main() -> None: with make_conn(database_config, engine, "review_recent_signups") as db_conn: # This generates a type of Cursor, not LoggingTransaction. - user_infos = get_recent_users(db_conn.cursor(), since_ms, exclude_users_with_appservice) # type: ignore[arg-type] + user_infos = get_recent_users( + db_conn.cursor(), + since_ms, # type: ignore[arg-type] + exclude_users_with_appservice, + ) for user_info in user_infos: if exclude_users_with_email and user_info.emails: diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 195c95d376..31639d366e 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -717,9 +717,7 @@ class Porter: return # Check if all background updates are done, abort if not. - updates_complete = ( - await self.sqlite_store.db_pool.updates.has_completed_background_updates() - ) + updates_complete = await self.sqlite_store.db_pool.updates.has_completed_background_updates() if not updates_complete: end_error = ( "Pending background updates exist in the SQLite3 database." @@ -1095,10 +1093,10 @@ class Porter: return done, remaining + done async def _setup_state_group_id_seq(self) -> None: - curr_id: Optional[int] = ( - await self.sqlite_store.db_pool.simple_select_one_onecol( - table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True - ) + curr_id: Optional[ + int + ] = await self.sqlite_store.db_pool.simple_select_one_onecol( + table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True ) if not curr_id: @@ -1186,13 +1184,13 @@ class Porter: ) async def _setup_auth_chain_sequence(self) -> None: - curr_chain_id: Optional[int] = ( - await self.sqlite_store.db_pool.simple_select_one_onecol( - table="event_auth_chains", - keyvalues={}, - retcol="MAX(chain_id)", - allow_none=True, - ) + curr_chain_id: Optional[ + int + ] = await self.sqlite_store.db_pool.simple_select_one_onecol( + table="event_auth_chains", + keyvalues={}, + retcol="MAX(chain_id)", + allow_none=True, ) def r(txn: LoggingTransaction) -> None: diff --git a/synapse/api/urls.py b/synapse/api/urls.py index d077a2c613..03a3e96f28 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -19,7 +19,8 @@ # # -"""Contains the URL paths to prefix various aspects of the server with. """ +"""Contains the URL paths to prefix various aspects of the server with.""" + import hmac from hashlib import sha256 from urllib.parse import urlencode diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index bec83419a2..7994da0868 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -54,6 +54,7 @@ UP & quit +---------- YES SUCCESS This is all tied together by the AppServiceScheduler which DIs the required components. """ + import logging from typing import ( TYPE_CHECKING, diff --git a/synapse/config/key.py b/synapse/config/key.py index b9925a52d2..bc96888967 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -200,16 +200,13 @@ class KeyConfig(Config): ) form_secret = 'form_secret: "%s"' % random_string_with_symbols(50) - return ( - """\ + return """\ %(macaroon_secret_key)s %(form_secret)s signing_key_path: "%(base_key_name)s.signing.key" trusted_key_servers: - server_name: "matrix.org" - """ - % locals() - ) + """ % locals() def read_signing_keys(self, signing_key_path: str, name: str) -> List[SigningKey]: """Read the signing keys in the given path. @@ -249,7 +246,9 @@ class KeyConfig(Config): if is_signing_algorithm_supported(key_id): key_base64 = key_data["key"] key_bytes = decode_base64(key_base64) - verify_key: "VerifyKeyWithExpiry" = decode_verify_key_bytes(key_id, key_bytes) # type: ignore[assignment] + verify_key: "VerifyKeyWithExpiry" = decode_verify_key_bytes( + key_id, key_bytes + ) # type: ignore[assignment] verify_key.expired = key_data["expired_ts"] keys[key_id] = verify_key else: diff --git a/synapse/config/logger.py b/synapse/config/logger.py index fca0b08d6d..cfc1a57107 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -157,12 +157,9 @@ class LoggingConfig(Config): self, config_dir_path: str, server_name: str, **kwargs: Any ) -> str: log_config = os.path.join(config_dir_path, server_name + ".log.config") - return ( - """\ + return """\ log_config: "%(log_config)s" - """ - % locals() - ) + """ % locals() def read_arguments(self, args: argparse.Namespace) -> None: if args.no_redirect_stdio is not None: diff --git a/synapse/config/server.py b/synapse/config/server.py index fd52c0475c..488604a30c 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -828,13 +828,10 @@ class ServerConfig(Config): ).lstrip() if not unsecure_listeners: - unsecure_http_bindings = ( - """- port: %(unsecure_port)s + unsecure_http_bindings = """- port: %(unsecure_port)s tls: false type: http - x_forwarded: true""" - % locals() - ) + x_forwarded: true""" % locals() if not open_private_ports: unsecure_http_bindings += ( @@ -853,16 +850,13 @@ class ServerConfig(Config): if not secure_listeners: secure_http_bindings = "" - return ( - """\ + return """\ server_name: "%(server_name)s" pid_file: %(pid_file)s listeners: %(secure_http_bindings)s %(unsecure_http_bindings)s - """ - % locals() - ) + """ % locals() def read_arguments(self, args: argparse.Namespace) -> None: if args.manhole is not None: diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 7ecf349e4a..b013ffa354 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -328,10 +328,11 @@ class WorkerConfig(Config): ) # type-ignore: the expression `Union[A, B]` is not a Type[Union[A, B]] currently - self.instance_map: Dict[ - str, InstanceLocationConfig - ] = parse_and_validate_mapping( - instance_map, InstanceLocationConfig # type: ignore[arg-type] + self.instance_map: Dict[str, InstanceLocationConfig] = ( + parse_and_validate_mapping( + instance_map, + InstanceLocationConfig, # type: ignore[arg-type] + ) ) # Map from type of streams to source, c.f. WriterLocations. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index f5abcde2db..b834547d11 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -887,7 +887,8 @@ def _check_power_levels( raise SynapseError(400, f"{v!r} must be an integer.") if k in {"events", "notifications", "users"}: if not isinstance(v, collections.abc.Mapping) or not all( - type(v) is int for v in v.values() # noqa: E721 + type(v) is int + for v in v.values() # noqa: E721 ): raise SynapseError( 400, diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py index 9cb053cd8e..9713b141bc 100644 --- a/synapse/events/presence_router.py +++ b/synapse/events/presence_router.py @@ -80,7 +80,7 @@ def load_legacy_presence_router(hs: "HomeServer") -> None: # All methods that the module provides should be async, but this wasn't enforced # in the old module system, so we wrap them if needed def async_wrapper( - f: Optional[Callable[P, R]] + f: Optional[Callable[P, R]], ) -> Optional[Callable[P, Awaitable[R]]]: # f might be None if the callback isn't implemented by the module. In this # case we don't want to register a callback at all so we return None. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 6b70ea94d1..dd21a6136b 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -504,7 +504,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase): def _encode_state_group_delta( - state_group_delta: Dict[Tuple[int, int], StateMap[str]] + state_group_delta: Dict[Tuple[int, int], StateMap[str]], ) -> List[Tuple[int, int, Optional[List[Tuple[str, str, str]]]]]: if not state_group_delta: return [] @@ -517,7 +517,7 @@ def _encode_state_group_delta( def _decode_state_group_delta( - input: List[Tuple[int, int, List[Tuple[str, str, str]]]] + input: List[Tuple[int, int, List[Tuple[str, str, str]]]], ) -> Dict[Tuple[int, int], StateMap[str]]: if not input: return {} @@ -544,7 +544,7 @@ def _encode_state_dict( def _decode_state_dict( - input: Optional[List[Tuple[str, str, str]]] + input: Optional[List[Tuple[str, str, str]]], ) -> Optional[StateMap[str]]: """Decodes a state dict encoded using `_encode_state_dict` above""" if input is None: diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py index a571eff590..61e28bff66 100644 --- a/synapse/federation/__init__.py +++ b/synapse/federation/__init__.py @@ -19,5 +19,4 @@ # # -""" This package includes all the federation specific logic. -""" +"""This package includes all the federation specific logic.""" diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 0bfde00315..8340b48503 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -20,7 +20,7 @@ # # -""" This module contains all the persistence actions done by the federation +"""This module contains all the persistence actions done by the federation package. These actions are mostly only used by the :py:mod:`.replication` module. diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 20f87c885e..a05e5d5319 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -859,7 +859,6 @@ class FederationMediaThumbnailServlet(BaseFederationServerServlet): request: SynapseRequest, media_id: str, ) -> None: - width = parse_integer(request, "width", required=True) height = parse_integer(request, "height", required=True) method = parse_string(request, "method", "scale") diff --git a/synapse/federation/units.py b/synapse/federation/units.py index b2c8ba5887..d8b67a6a5b 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -19,7 +19,7 @@ # # -""" Defines the JSON structure of the protocol units used by the server to +"""Defines the JSON structure of the protocol units used by the server to server protocol. """ diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py index 89e944bc17..37cc3d3ff5 100644 --- a/synapse/handlers/account.py +++ b/synapse/handlers/account.py @@ -118,10 +118,10 @@ class AccountHandler: } if self._use_account_validity_in_account_status: - status["org.matrix.expired"] = ( - await self._account_validity_handler.is_user_expired( - user_id.to_string() - ) + status[ + "org.matrix.expired" + ] = await self._account_validity_handler.is_user_expired( + user_id.to_string() ) return status diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index b44e862493..c874d22eac 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -197,14 +197,15 @@ class AdminHandler: # events that we have and then filtering, this isn't the most # efficient method perhaps but it does guarantee we get everything. while True: - events, _ = ( - await self._store.paginate_room_events_by_topological_ordering( - room_id=room_id, - from_key=from_key, - to_key=to_key, - limit=100, - direction=Direction.FORWARDS, - ) + ( + events, + _, + ) = await self._store.paginate_room_events_by_topological_ordering( + room_id=room_id, + from_key=from_key, + to_key=to_key, + limit=100, + direction=Direction.FORWARDS, ) if not events: break diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index a1fab99f6b..1f4264ad7e 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -166,8 +166,7 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]: if "country" not in identifier or ( # The specification requires a "phone" field, while Synapse used to require a "number" # field. Accept both for backwards compatibility. - "phone" not in identifier - and "number" not in identifier + "phone" not in identifier and "number" not in identifier ): raise SynapseError( 400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index ad2b0f5fcc..62ce16794f 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -265,9 +265,9 @@ class DirectoryHandler: async def get_association(self, room_alias: RoomAlias) -> JsonDict: room_id = None if self.hs.is_mine(room_alias): - result: Optional[RoomAliasMapping] = ( - await self.get_association_from_room_alias(room_alias) - ) + result: Optional[ + RoomAliasMapping + ] = await self.get_association_from_room_alias(room_alias) if result: room_id = result.room_id @@ -512,11 +512,9 @@ class DirectoryHandler: raise SynapseError(403, "Not allowed to publish room") # Check if publishing is blocked by a third party module - allowed_by_third_party_rules = ( - await ( - self._third_party_event_rules.check_visibility_can_be_modified( - room_id, visibility - ) + allowed_by_third_party_rules = await ( + self._third_party_event_rules.check_visibility_can_be_modified( + room_id, visibility ) ) if not allowed_by_third_party_rules: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 299588e476..2b7aad5b58 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1001,11 +1001,11 @@ class FederationHandler: ) if include_auth_user_id: - event_content[EventContentFields.AUTHORISING_USER] = ( - await self._event_auth_handler.get_user_which_could_invite( - room_id, - state_ids, - ) + event_content[ + EventContentFields.AUTHORISING_USER + ] = await self._event_auth_handler.get_user_which_could_invite( + room_id, + state_ids, ) builder = self.event_builder_factory.for_room_version( diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index cb31d65aa9..89191217d6 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -21,6 +21,7 @@ # """Utilities for interacting with Identity Servers""" + import logging import urllib.parse from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5aa48230ec..204965afee 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1225,10 +1225,9 @@ class EventCreationHandler: ) if prev_event_ids is not None: - assert ( - len(prev_event_ids) <= 10 - ), "Attempting to create an event with %i prev_events" % ( - len(prev_event_ids), + assert len(prev_event_ids) <= 10, ( + "Attempting to create an event with %i prev_events" + % (len(prev_event_ids),) ) else: prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 6fd7afa280..3c44458fa3 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -507,15 +507,16 @@ class PaginationHandler: # Initially fetch the events from the database. With any luck, we can return # these without blocking on backfill (handled below). - events, next_key = ( - await self.store.paginate_room_events_by_topological_ordering( - room_id=room_id, - from_key=from_token.room_key, - to_key=to_room_key, - direction=pagin_config.direction, - limit=pagin_config.limit, - event_filter=event_filter, - ) + ( + events, + next_key, + ) = await self.store.paginate_room_events_by_topological_ordering( + room_id=room_id, + from_key=from_token.room_key, + to_key=to_room_key, + direction=pagin_config.direction, + limit=pagin_config.limit, + event_filter=event_filter, ) if pagin_config.direction == Direction.BACKWARDS: @@ -584,15 +585,16 @@ class PaginationHandler: # If we did backfill something, refetch the events from the database to # catch anything new that might have been added since we last fetched. if did_backfill: - events, next_key = ( - await self.store.paginate_room_events_by_topological_ordering( - room_id=room_id, - from_key=from_token.room_key, - to_key=to_room_key, - direction=pagin_config.direction, - limit=pagin_config.limit, - event_filter=event_filter, - ) + ( + events, + next_key, + ) = await self.store.paginate_room_events_by_topological_ordering( + room_id=room_id, + from_key=from_token.room_key, + to_key=to_room_key, + direction=pagin_config.direction, + limit=pagin_config.limit, + event_filter=event_filter, ) else: # Otherwise, we can backfill in the background for eventual diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 37ee625f71..390cafa8f6 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -71,6 +71,7 @@ user state; this device follows the normal timeout logic (see above) and will automatically be replaced with any information from currently available devices. """ + import abc import contextlib import itertools @@ -493,9 +494,9 @@ class WorkerPresenceHandler(BasePresenceHandler): # The number of ongoing syncs on this process, by (user ID, device ID). # Empty if _presence_enabled is false. - self._user_device_to_num_current_syncs: Dict[Tuple[str, Optional[str]], int] = ( - {} - ) + self._user_device_to_num_current_syncs: Dict[ + Tuple[str, Optional[str]], int + ] = {} self.notifier = hs.get_notifier() self.instance_id = hs.get_instance_id() @@ -818,9 +819,9 @@ class PresenceHandler(BasePresenceHandler): # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. - self._user_device_to_num_current_syncs: Dict[Tuple[str, Optional[str]], int] = ( - {} - ) + self._user_device_to_num_current_syncs: Dict[ + Tuple[str, Optional[str]], int + ] = {} # Keeps track of the number of *ongoing* syncs on other processes. # diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index af8cd838ee..ac4544ca4c 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -351,9 +351,9 @@ class ProfileHandler: server_name = host if self._is_mine_server_name(server_name): - media_info: Optional[Union[LocalMedia, RemoteMedia]] = ( - await self.store.get_local_media(media_id) - ) + media_info: Optional[ + Union[LocalMedia, RemoteMedia] + ] = await self.store.get_local_media(media_id) else: media_info = await self.store.get_cached_remote_media(server_name, media_id) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index efe31e81f9..b1158ee77d 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -188,13 +188,13 @@ class RelationsHandler: if include_original_event: # Do not bundle aggregations when retrieving the original event because # we want the content before relations are applied to it. - return_value["original_event"] = ( - await self._event_serializer.serialize_event( - event, - now, - bundle_aggregations=None, - config=serialize_options, - ) + return_value[ + "original_event" + ] = await self._event_serializer.serialize_event( + event, + now, + bundle_aggregations=None, + config=serialize_options, ) if next_token: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2c6e672ede..35c88f1b91 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -20,6 +20,7 @@ # """Contains functions for performing actions on rooms.""" + import itertools import logging import math @@ -900,11 +901,9 @@ class RoomCreationHandler: ) # Check whether this visibility value is blocked by a third party module - allowed_by_third_party_rules = ( - await ( - self._third_party_event_rules.check_visibility_can_be_modified( - room_id, visibility - ) + allowed_by_third_party_rules = await ( + self._third_party_event_rules.check_visibility_can_be_modified( + room_id, visibility ) ) if not allowed_by_third_party_rules: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 51b9772329..75c60e3c34 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1302,11 +1302,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # If this is going to be a local join, additional information must # be included in the event content in order to efficiently validate # the event. - content[EventContentFields.AUTHORISING_USER] = ( - await self.event_auth_handler.get_user_which_could_invite( - room_id, - state_before_join, - ) + content[ + EventContentFields.AUTHORISING_USER + ] = await self.event_auth_handler.get_user_which_could_invite( + room_id, + state_before_join, ) return False, [] @@ -1415,9 +1415,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if requester is not None: sender = UserID.from_string(event.sender) - assert ( - sender == requester.user - ), "Sender (%s) must be same as requester (%s)" % (sender, requester.user) + assert sender == requester.user, ( + "Sender (%s) must be same as requester (%s)" % (sender, requester.user) + ) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) else: requester = types.create_requester(target_user) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index a7d52fa648..1a71135d5f 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -423,9 +423,9 @@ class SearchHandler: } if search_result.room_groups and "room_id" in group_keys: - rooms_cat_res.setdefault("groups", {})[ - "room_id" - ] = search_result.room_groups + rooms_cat_res.setdefault("groups", {})["room_id"] = ( + search_result.room_groups + ) if sender_group and "sender" in group_keys: rooms_cat_res.setdefault("groups", {})["sender"] = sender_group diff --git a/synapse/handlers/sliding_sync/__init__.py b/synapse/handlers/sliding_sync/__init__.py index d92bdad307..f79796a336 100644 --- a/synapse/handlers/sliding_sync/__init__.py +++ b/synapse/handlers/sliding_sync/__init__.py @@ -587,9 +587,7 @@ class SlidingSyncHandler: Membership.LEAVE, Membership.BAN, ): - to_bound = ( - room_membership_for_user_at_to_token.event_pos.to_room_stream_token() - ) + to_bound = room_membership_for_user_at_to_token.event_pos.to_room_stream_token() timeline_from_bound = from_bound if ignore_timeline_bound: diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py index d9f4c56e6e..6f37cc3462 100644 --- a/synapse/handlers/sliding_sync/extensions.py +++ b/synapse/handlers/sliding_sync/extensions.py @@ -386,9 +386,9 @@ class SlidingSyncExtensionHandler: if have_push_rules_changed: global_account_data_map = dict(global_account_data_map) # TODO: This should take into account the `from_token` and `to_token` - global_account_data_map[AccountDataTypes.PUSH_RULES] = ( - await self.push_rules_handler.push_rules_for_user(sync_config.user) - ) + global_account_data_map[ + AccountDataTypes.PUSH_RULES + ] = await self.push_rules_handler.push_rules_for_user(sync_config.user) else: # TODO: This should take into account the `to_token` all_global_account_data = await self.store.get_global_account_data_for_user( @@ -397,9 +397,9 @@ class SlidingSyncExtensionHandler: global_account_data_map = dict(all_global_account_data) # TODO: This should take into account the `to_token` - global_account_data_map[AccountDataTypes.PUSH_RULES] = ( - await self.push_rules_handler.push_rules_for_user(sync_config.user) - ) + global_account_data_map[ + AccountDataTypes.PUSH_RULES + ] = await self.push_rules_handler.push_rules_for_user(sync_config.user) # Fetch room account data account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]] = {} diff --git a/synapse/handlers/sliding_sync/room_lists.py b/synapse/handlers/sliding_sync/room_lists.py index 12b7958c6f..1423d6ca53 100644 --- a/synapse/handlers/sliding_sync/room_lists.py +++ b/synapse/handlers/sliding_sync/room_lists.py @@ -293,10 +293,11 @@ class SlidingSyncRoomLists: is_encrypted=is_encrypted, ) - newly_joined_room_ids, newly_left_room_map = ( - await self._get_newly_joined_and_left_rooms( - user_id, from_token=from_token, to_token=to_token - ) + ( + newly_joined_room_ids, + newly_left_room_map, + ) = await self._get_newly_joined_and_left_rooms( + user_id, from_token=from_token, to_token=to_token ) dm_room_ids = await self._get_dm_rooms_for_user(user_id) @@ -958,10 +959,11 @@ class SlidingSyncRoomLists: else: rooms_for_user[room_id] = change_room_for_user - newly_joined_room_ids, newly_left_room_ids = ( - await self._get_newly_joined_and_left_rooms( - user_id, to_token=to_token, from_token=from_token - ) + ( + newly_joined_room_ids, + newly_left_room_ids, + ) = await self._get_newly_joined_and_left_rooms( + user_id, to_token=to_token, from_token=from_token ) dm_room_ids = await self._get_dm_rooms_for_user(user_id) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c44baa7042..609840bfe9 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -183,10 +183,7 @@ class JoinedSyncResult: to tell if room needs to be part of the sync result. """ return bool( - self.timeline - or self.state - or self.ephemeral - or self.account_data + self.timeline or self.state or self.ephemeral or self.account_data # nb the notification count does not, er, count: if there's nothing # else in the result, we don't need to send it. ) @@ -575,10 +572,10 @@ class SyncHandler: if timeout == 0 or since_token is None or full_state: # we are going to return immediately, so don't bother calling # notifier.wait_for_events. - result: Union[SyncResult, E2eeSyncResult] = ( - await self.current_sync_for_user( - sync_config, sync_version, since_token, full_state=full_state - ) + result: Union[ + SyncResult, E2eeSyncResult + ] = await self.current_sync_for_user( + sync_config, sync_version, since_token, full_state=full_state ) else: # Otherwise, we wait for something to happen and report it to the user. @@ -673,10 +670,10 @@ class SyncHandler: # Go through the `/sync` v2 path if sync_version == SyncVersion.SYNC_V2: - sync_result: Union[SyncResult, E2eeSyncResult] = ( - await self.generate_sync_result( - sync_config, since_token, full_state - ) + sync_result: Union[ + SyncResult, E2eeSyncResult + ] = await self.generate_sync_result( + sync_config, since_token, full_state ) # Go through the MSC3575 Sliding Sync `/sync/e2ee` path elif sync_version == SyncVersion.E2EE_SYNC: @@ -1488,13 +1485,16 @@ class SyncHandler: # timeline here. The caller will then dedupe any redundant # ones. - state_ids = await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, - # we only want members! - state_filter=StateFilter.from_types( - (EventTypes.Member, member) for member in members_to_fetch - ), - await_full_state=False, + state_ids = ( + await self._state_storage_controller.get_state_ids_for_event( + batch.events[0].event_id, + # we only want members! + state_filter=StateFilter.from_types( + (EventTypes.Member, member) + for member in members_to_fetch + ), + await_full_state=False, + ) ) return state_ids @@ -2166,18 +2166,18 @@ class SyncHandler: if push_rules_changed: global_account_data = dict(global_account_data) - global_account_data[AccountDataTypes.PUSH_RULES] = ( - await self._push_rules_handler.push_rules_for_user(sync_config.user) - ) + global_account_data[ + AccountDataTypes.PUSH_RULES + ] = await self._push_rules_handler.push_rules_for_user(sync_config.user) else: all_global_account_data = await self.store.get_global_account_data_for_user( user_id ) global_account_data = dict(all_global_account_data) - global_account_data[AccountDataTypes.PUSH_RULES] = ( - await self._push_rules_handler.push_rules_for_user(sync_config.user) - ) + global_account_data[ + AccountDataTypes.PUSH_RULES + ] = await self._push_rules_handler.push_rules_for_user(sync_config.user) account_data_for_user = ( await sync_config.filter_collection.filter_global_account_data( diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py index 7e578cf462..db998f6701 100644 --- a/synapse/handlers/worker_lock.py +++ b/synapse/handlers/worker_lock.py @@ -183,7 +183,7 @@ class WorkerLocksHandler: return def _wake_all_locks( - locks: Collection[Union[WaitingLock, WaitingMultiLock]] + locks: Collection[Union[WaitingLock, WaitingMultiLock]], ) -> None: for lock in locks: deferred = lock.deferred diff --git a/synapse/http/client.py b/synapse/http/client.py index cb4f72d771..143fee9796 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -1313,6 +1313,5 @@ def is_unknown_endpoint( ) ) or ( # Older Synapses returned a 400 error. - e.code == 400 - and synapse_error.errcode == Codes.UNRECOGNIZED + e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED ) diff --git a/synapse/http/server.py b/synapse/http/server.py index 211795dc39..3e2d94d399 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -233,7 +233,7 @@ def return_html_error( def wrap_async_request_handler( - h: Callable[["_AsyncResource", "SynapseRequest"], Awaitable[None]] + h: Callable[["_AsyncResource", "SynapseRequest"], Awaitable[None]], ) -> Callable[["_AsyncResource", "SynapseRequest"], "defer.Deferred[None]"]: """Wraps an async request handler so that it calls request.processing. diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index 6a6afbfc0b..d9ff70b252 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -22,6 +22,7 @@ """ Log formatters that output terse JSON. """ + import json import logging diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 4650b60962..ae2b3d11c0 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -20,7 +20,7 @@ # # -""" Thread-local-alike tracking of log contexts within synapse +"""Thread-local-alike tracking of log contexts within synapse This module provides objects and utilities for tracking contexts through synapse code, so that log lines can include a request identifier, and so that @@ -29,6 +29,7 @@ them. See doc/log_contexts.rst for details on how this works. """ + import logging import threading import typing @@ -751,7 +752,7 @@ def preserve_fn( f: Union[ Callable[P, R], Callable[P, Awaitable[R]], - ] + ], ) -> Callable[P, "defer.Deferred[R]"]: """Function decorator which wraps the function with run_in_background""" diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index e32b3f6781..d976e58e49 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -169,6 +169,7 @@ Gotchas than one caller? Will all of those calling functions have be in a context with an active span? """ + import contextlib import enum import inspect @@ -414,7 +415,7 @@ def ensure_active_span( """ def ensure_active_span_inner_1( - func: Callable[P, R] + func: Callable[P, R], ) -> Callable[P, Union[Optional[T], R]]: @wraps(func) def ensure_active_span_inner_2( @@ -700,7 +701,7 @@ def set_operation_name(operation_name: str) -> None: @only_if_tracing def force_tracing( - span: Union["opentracing.Span", _Sentinel] = _Sentinel.sentinel + span: Union["opentracing.Span", _Sentinel] = _Sentinel.sentinel, ) -> None: """Force sampling for the active/given span and its children. @@ -1093,9 +1094,10 @@ def trace_servlet( # Mypy seems to think that start_context.tag below can be Optional[str], but # that doesn't appear to be correct and works in practice. - request_tags[ - SynapseTags.REQUEST_TAG - ] = request.request_metrics.start_context.tag # type: ignore[assignment] + + request_tags[SynapseTags.REQUEST_TAG] = ( + request.request_metrics.start_context.tag # type: ignore[assignment] + ) # set the tags *after* the servlet completes, in case it decided to # prioritise the span (tags will get dropped on unprioritised spans) diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 19c92b02a0..49d0ff9fc1 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -293,7 +293,7 @@ def wrap_as_background_process( """ def wrap_as_background_process_inner( - func: Callable[P, Awaitable[Optional[R]]] + func: Callable[P, Awaitable[Optional[R]]], ) -> Callable[P, "defer.Deferred[Optional[R]]"]: @wraps(func) def wrap_as_background_process_inner_2( diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 34ab637c3d..679cbe9afa 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -304,9 +304,9 @@ class BulkPushRuleEvaluator: if relation_type == "m.thread" and event.content.get( "m.relates_to", {} ).get("is_falling_back", False): - related_events["m.in_reply_to"][ - "im.vector.is_falling_back" - ] = "" + related_events["m.in_reply_to"]["im.vector.is_falling_back"] = ( + "" + ) return related_events @@ -372,7 +372,8 @@ class BulkPushRuleEvaluator: gather_results( ( run_in_background( # type: ignore[call-arg] - self.store.get_number_joined_users_in_room, event.room_id # type: ignore[arg-type] + self.store.get_number_joined_users_in_room, + event.room_id, # type: ignore[arg-type] ), run_in_background( self._get_power_levels_and_sender_level, diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 9c537427df..940f418396 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -119,7 +119,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): return payload - async def _handle_request(self, request: Request, content: JsonDict) -> Tuple[int, JsonDict]: # type: ignore[override] + async def _handle_request( # type: ignore[override] + self, request: Request, content: JsonDict + ) -> Tuple[int, JsonDict]: with Measure(self.clock, "repl_fed_send_events_parse"): room_id = content["room_id"] backfilled = content["backfilled"] diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py index de07e75b46..2e06c43ce5 100644 --- a/synapse/replication/http/push.py +++ b/synapse/replication/http/push.py @@ -98,7 +98,9 @@ class ReplicationCopyPusherRestServlet(ReplicationEndpoint): self._store = hs.get_datastores().main @staticmethod - async def _serialize_payload(user_id: str, old_room_id: str, new_room_id: str) -> JsonDict: # type: ignore[override] + async def _serialize_payload( # type: ignore[override] + user_id: str, old_room_id: str, new_room_id: str + ) -> JsonDict: return {} async def _handle_request( # type: ignore[override] @@ -109,7 +111,6 @@ class ReplicationCopyPusherRestServlet(ReplicationEndpoint): old_room_id: str, new_room_id: str, ) -> Tuple[int, JsonDict]: - await self._store.copy_push_rules_from_room_to_room_for_user( old_room_id, new_room_id, user_id ) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 3dddbb70b4..0bd5478cd3 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -18,8 +18,8 @@ # [This file includes modifications made by New Vector Limited] # # -"""A replication client for use by synapse workers. -""" +"""A replication client for use by synapse workers.""" + import logging from typing import TYPE_CHECKING, Dict, Iterable, Optional, Set, Tuple diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index b7a7e77597..7d51441e91 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -23,6 +23,7 @@ The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are allowed to be sent by which side. """ + import abc import logging from typing import List, Optional, Tuple, Type, TypeVar diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 72a42cb6cc..6101226938 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -857,7 +857,7 @@ UpdateRow = TypeVar("UpdateRow") def _batch_updates( - updates: Iterable[Tuple[UpdateToken, UpdateRow]] + updates: Iterable[Tuple[UpdateToken, UpdateRow]], ) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]: """Collect stream updates with the same token together diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 4471cc8f0c..fb9c539122 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -23,6 +23,7 @@ protocols. An explanation of this protocol is available in docs/tcp_replication.md """ + import fcntl import logging import struct diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index c0329378ac..d647a2b332 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -18,8 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -"""The server side of the replication stream. -""" +"""The server side of the replication stream.""" import logging import random @@ -307,7 +306,7 @@ class ReplicationStreamer: def _batch_updates( - updates: List[Tuple[Token, StreamRow]] + updates: List[Tuple[Token, StreamRow]], ) -> List[Tuple[Optional[Token], StreamRow]]: """Takes a list of updates of form [(token, row)] and sets the token to None for all rows where the next row has the same token. This is used to diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index d021904de7..ebf5964d29 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -247,7 +247,7 @@ class _StreamFromIdGen(Stream): def current_token_without_instance( - current_token: Callable[[], int] + current_token: Callable[[], int], ) -> Callable[[str], int]: """Takes a current token callback function for a single writer stream that doesn't take an instance name parameter and wraps it in a function that diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index 0867f7a51c..bec2331590 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -181,8 +181,7 @@ class NewRegistrationTokenRestServlet(RestServlet): uses_allowed = body.get("uses_allowed", None) if not ( - uses_allowed is None - or (type(uses_allowed) is int and uses_allowed >= 0) # noqa: E721 + uses_allowed is None or (type(uses_allowed) is int and uses_allowed >= 0) # noqa: E721 ): raise SynapseError( HTTPStatus.BAD_REQUEST, diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py index 93dec6375a..6cf37869d8 100644 --- a/synapse/rest/client/_base.py +++ b/synapse/rest/client/_base.py @@ -19,8 +19,8 @@ # # -"""This module contains base REST classes for constructing client v1 servlets. -""" +"""This module contains base REST classes for constructing client v1 servlets.""" + import logging import re from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, cast diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py index 0ee24081fa..734c9e992f 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py @@ -108,9 +108,9 @@ class AccountDataServlet(RestServlet): # Push rules are stored in a separate table and must be queried separately. if account_data_type == AccountDataTypes.PUSH_RULES: - account_data: Optional[JsonMapping] = ( - await self._push_rules_handler.push_rules_for_user(requester.user) - ) + account_data: Optional[ + JsonMapping + ] = await self._push_rules_handler.push_rules_for_user(requester.user) else: account_data = await self.store.get_global_account_data_by_type_for_user( user_id, account_data_type diff --git a/synapse/rest/client/account_validity.py b/synapse/rest/client/account_validity.py index 6222a5cc37..ec7836b647 100644 --- a/synapse/rest/client/account_validity.py +++ b/synapse/rest/client/account_validity.py @@ -48,9 +48,7 @@ class AccountValidityRenewServlet(RestServlet): self.account_renewed_template = ( hs.config.account_validity.account_validity_account_renewed_template ) - self.account_previously_renewed_template = ( - hs.config.account_validity.account_validity_account_previously_renewed_template - ) + self.account_previously_renewed_template = hs.config.account_validity.account_validity_account_previously_renewed_template self.invalid_token_template = ( hs.config.account_validity.account_validity_invalid_token_template ) diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 613890061e..ad23cc76ce 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -20,6 +20,7 @@ # """This module contains REST servlets to do with event streaming, /events.""" + import logging from typing import TYPE_CHECKING, Dict, List, Tuple, Union diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py index 572e92642c..ecc52956e4 100644 --- a/synapse/rest/client/presence.py +++ b/synapse/rest/client/presence.py @@ -19,8 +19,8 @@ # # -""" This module contains REST servlets to do with presence: /presence/ -""" +"""This module contains REST servlets to do with presence: /presence/""" + import logging from typing import TYPE_CHECKING, Tuple diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index c1a80c5c3d..7a95b9445d 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -19,7 +19,7 @@ # # -""" This module contains REST servlets to do with profile: /profile/ """ +"""This module contains REST servlets to do with profile: /profile/""" from http import HTTPStatus from typing import TYPE_CHECKING, Tuple diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 5dddbc69be..61e1436841 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -640,12 +640,10 @@ class RegisterRestServlet(RestServlet): if not password_hash: raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) - desired_username = ( - await ( - self.password_auth_provider.get_username_for_registration( - auth_result, - params, - ) + desired_username = await ( + self.password_auth_provider.get_username_for_registration( + auth_result, + params, ) ) @@ -696,11 +694,9 @@ class RegisterRestServlet(RestServlet): session_id ) - display_name = ( - await ( - self.password_auth_provider.get_displayname_for_registration( - auth_result, params - ) + display_name = await ( + self.password_auth_provider.get_displayname_for_registration( + auth_result, params ) ) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 7d57904d69..83f84e4998 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -19,7 +19,8 @@ # # -""" This module contains REST servlets to do with rooms: /rooms/ """ +"""This module contains REST servlets to do with rooms: /rooms/""" + import logging import re from enum import Enum diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 22c85e497a..cc9fbfe546 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -1045,9 +1045,9 @@ class SlidingSyncRestServlet(RestServlet): serialized_rooms[room_id]["initial"] = room_result.initial if room_result.unstable_expanded_timeline: - serialized_rooms[room_id][ - "unstable_expanded_timeline" - ] = room_result.unstable_expanded_timeline + serialized_rooms[room_id]["unstable_expanded_timeline"] = ( + room_result.unstable_expanded_timeline + ) # This will be omitted for invite/knock rooms with `stripped_state` if ( @@ -1082,9 +1082,9 @@ class SlidingSyncRestServlet(RestServlet): # This will be omitted for invite/knock rooms with `stripped_state` if room_result.prev_batch is not None: - serialized_rooms[room_id]["prev_batch"] = ( - await room_result.prev_batch.to_string(self.store) - ) + serialized_rooms[room_id][ + "prev_batch" + ] = await room_result.prev_batch.to_string(self.store) # This will be omitted for invite/knock rooms with `stripped_state` if room_result.num_live is not None: diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 30c1f17fc6..f791904168 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -21,6 +21,7 @@ """This module contains logic for storing HTTP PUT transactions. This is used to ensure idempotency when performing PUTs using the REST API.""" + import logging from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Hashable, Tuple diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 1975ebb477..3c2028a2ad 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -191,10 +191,10 @@ class RemoteKey(RestServlet): server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {} for server_name, key_ids in query.items(): if key_ids: - results: Mapping[str, Optional[FetchKeyResultForRemote]] = ( - await self.store.get_server_keys_json_for_remote( - server_name, key_ids - ) + results: Mapping[ + str, Optional[FetchKeyResultForRemote] + ] = await self.store.get_server_keys_json_for_remote( + server_name, key_ids ) else: results = await self.store.get_all_server_keys_json_for_remote( diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index 989e570671..d336d60c93 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -65,9 +65,9 @@ class WellKnownBuilder: } account_management_url = await auth.account_management_url() if account_management_url is not None: - result["org.matrix.msc2965.authentication"][ - "account" - ] = account_management_url + result["org.matrix.msc2965.authentication"]["account"] = ( + account_management_url + ) if self._config.server.extra_well_known_client_content: for ( diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index f6ea90bd4f..e88e8c9b45 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -119,7 +119,9 @@ class ResourceLimitsServerNotices: elif not currently_blocked and limit_msg: # Room is not notifying of a block, when it ought to be. await self._apply_limit_block_notification( - user_id, limit_msg, limit_type # type: ignore + user_id, + limit_msg, + limit_type, # type: ignore ) except SynapseError as e: logger.error("Error sending resource limits server notice: %s", e) diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index ac0919340b..879ee9039e 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -416,7 +416,7 @@ class EventsPersistenceStorageController: set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled)) async def enqueue( - item: Tuple[str, List[Tuple[EventBase, EventContext]]] + item: Tuple[str, List[Tuple[EventBase, EventContext]]], ) -> Dict[str, str]: room_id, evs_ctxs = item return await self._event_persist_queue.add_to_queue( @@ -792,9 +792,9 @@ class EventsPersistenceStorageController: ) # Remove any events which are prev_events of any existing events. - existing_prevs: Collection[str] = ( - await self.persist_events_store._get_events_which_are_prevs(result) - ) + existing_prevs: Collection[ + str + ] = await self.persist_events_store._get_events_which_are_prevs(result) result.difference_update(existing_prevs) # Finally handle the case where the new events have soft-failed prev diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 4b66247640..bf6cfcbfd9 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -238,9 +238,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): INNER JOIN user_ips USING (user_id, access_token, ip) GROUP BY user_id, access_token, ip HAVING count(*) > 1 - """.format( - clause - ), + """.format(clause), args, ) res = cast( @@ -373,9 +371,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): LIMIT ? ) c INNER JOIN user_ips AS u USING (user_id, device_id, last_seen) - """ % { - "where_clause": where_clause - } + """ % {"where_clause": where_clause} txn.execute(sql, where_args + [batch_size]) rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 042d595ea0..0612b82b9b 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -1116,7 +1116,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): txn.execute(sql, (start, stop)) - destinations = {d for d, in txn} + destinations = {d for (d,) in txn} to_remove = set() for d in destinations: try: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 53024bddc3..a83df4075a 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -670,9 +670,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): result["keys"] = keys device_display_name = None - if ( - self.hs.config.federation.allow_device_name_lookup_over_federation - ): + if self.hs.config.federation.allow_device_name_lookup_over_federation: device_display_name = device.display_name if device_display_name: result["device_display_name"] = device_display_name @@ -917,7 +915,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): from_key, to_key, ) - return {u for u, in rows} + return {u for (u,) in rows} @cancellable async def get_users_whose_devices_changed( @@ -968,7 +966,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): txn.database_engine, "user_id", chunk ) txn.execute(sql % (clause,), [from_key, to_key] + args) - changes.update(user_id for user_id, in txn) + changes.update(user_id for (user_id,) in txn) return changes @@ -1520,7 +1518,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): args: List[Any], ) -> Set[str]: txn.execute(sql.format(clause=clause), args) - return {user_id for user_id, in txn} + return {user_id for (user_id,) in txn} changes = set() for chunk in batch_iter(changed_room_ids, 1000): @@ -1560,7 +1558,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): txn: LoggingTransaction, ) -> Set[str]: txn.execute(sql, (from_id, to_id)) - return {room_id for room_id, in txn} + return {room_id for (room_id,) in txn} return await self.db_pool.runInteraction( "get_all_device_list_changes", diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 4d6a921ab2..c2c93e12d9 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -387,9 +387,7 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): is_verified, session_data FROM e2e_room_keys WHERE user_id = ? AND version = ? AND (%s) - """ % ( - " OR ".join(where_clauses) - ) + """ % (" OR ".join(where_clauses)) txn.execute(sql, params) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 9e6c9561ae..575aaf498b 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -472,9 +472,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker signature_sql = """ SELECT user_id, key_id, target_device_id, signature FROM e2e_cross_signing_signatures WHERE %s - """ % ( - " OR ".join("(" + q + ")" for q in signature_query_clauses) - ) + """ % (" OR ".join("(" + q + ")" for q in signature_query_clauses)) txn.execute(signature_sql, signature_query_params) return cast( @@ -917,9 +915,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker FROM e2e_cross_signing_keys WHERE %(clause)s ORDER BY user_id, keytype, stream_id DESC - """ % { - "clause": clause - } + """ % {"clause": clause} else: # SQLite has special handling for bare columns when using # MIN/MAX with a `GROUP BY` clause where it picks the value from @@ -929,9 +925,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker FROM e2e_cross_signing_keys WHERE %(clause)s GROUP BY user_id, keytype - """ % { - "clause": clause - } + """ % {"clause": clause} txn.execute(sql, params) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 715846865b..46aa5902d8 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -326,7 +326,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas """ rows = txn.execute_values(sql, chains.items()) - results.update(r for r, in rows) + results.update(r for (r,) in rows) else: # For SQLite we just fall back to doing a noddy for loop. sql = """ @@ -335,7 +335,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas """ for chain_id, max_no in chains.items(): txn.execute(sql, (chain_id, max_no)) - results.update(r for r, in txn) + results.update(r for (r,) in txn) return results @@ -645,7 +645,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ] rows = txn.execute_values(sql, args) - result.update(r for r, in rows) + result.update(r for (r,) in rows) else: # For SQLite we just fall back to doing a noddy for loop. sql = """ @@ -654,7 +654,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas """ for chain_id, (min_no, max_no) in chain_to_gap.items(): txn.execute(sql, (chain_id, min_no, max_no)) - result.update(r for r, in txn) + result.update(r for (r,) in txn) return result @@ -1220,13 +1220,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas HAVING count(*) > ? ORDER BY count(*) DESC LIMIT ? - """ % ( - where_clause, - ) + """ % (where_clause,) query_args = list(itertools.chain(room_id_filter, [min_count, limit])) txn.execute(sql, query_args) - return [room_id for room_id, in txn] + return [room_id for (room_id,) in txn] return await self.db_pool.runInteraction( "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn @@ -1358,7 +1356,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]: txn.execute(sql, (stream_ordering, room_id)) - return [event_id for event_id, in txn] + return [event_id for (event_id,) in txn] event_ids = await self.db_pool.runInteraction( "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 0ebf5b53d5..f42023418e 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -1860,9 +1860,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas AND epa.notif = 1 ORDER BY epa.stream_ordering DESC LIMIT ? - """ % ( - before_clause, - ) + """ % (before_clause,) txn.execute(sql, args) return cast( List[Tuple[str, str, int, int, str, bool, str, int]], txn.fetchall() diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index e44b8d8e54..d423d80efa 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -429,9 +429,7 @@ class PersistEventsStore: if event_type == EventTypes.Member and self.is_mine_id(state_key) ] - membership_snapshot_shared_insert_values: ( - SlidingSyncMembershipSnapshotSharedInsertValues - ) = {} + membership_snapshot_shared_insert_values: SlidingSyncMembershipSnapshotSharedInsertValues = {} membership_infos_to_insert_membership_snapshots: List[ SlidingSyncMembershipInfo ] = [] @@ -719,7 +717,7 @@ class PersistEventsStore: keyvalues={}, retcols=("event_id",), ) - already_persisted_events = {event_id for event_id, in rows} + already_persisted_events = {event_id for (event_id,) in rows} state_events = [ event for event in state_events @@ -1830,12 +1828,8 @@ class PersistEventsStore: if sliding_sync_table_changes.to_insert_membership_snapshots: # Update the `sliding_sync_membership_snapshots` table # - sliding_sync_snapshot_keys = ( - sliding_sync_table_changes.membership_snapshot_shared_insert_values.keys() - ) - sliding_sync_snapshot_values = ( - sliding_sync_table_changes.membership_snapshot_shared_insert_values.values() - ) + sliding_sync_snapshot_keys = sliding_sync_table_changes.membership_snapshot_shared_insert_values.keys() + sliding_sync_snapshot_values = sliding_sync_table_changes.membership_snapshot_shared_insert_values.values() # We need to insert/update regardless of whether we have # `sliding_sync_snapshot_keys` because there are other fields in the `ON # CONFLICT` upsert to run (see inherit case (explained in @@ -3361,7 +3355,7 @@ class PersistEventsStore: ) potential_backwards_extremities.difference_update( - e for e, in existing_events_outliers + e for (e,) in existing_events_outliers ) if potential_backwards_extremities: diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index b227e05773..4209100a5c 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -647,7 +647,8 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS room_ids = {row[0] for row in rows} for room_id in room_ids: txn.call_after( - self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined] + self.get_latest_event_ids_in_room.invalidate, # type: ignore[attr-defined] + (room_id,), ) self.db_pool.simple_delete_many_txn( @@ -2065,9 +2066,7 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS ) # Map of values to insert/update in the `sliding_sync_membership_snapshots` table - sliding_sync_membership_snapshots_insert_map: ( - SlidingSyncMembershipSnapshotSharedInsertValues - ) = {} + sliding_sync_membership_snapshots_insert_map: SlidingSyncMembershipSnapshotSharedInsertValues = {} if membership == Membership.JOIN: # If we're still joined, we can pull from current state. current_state_ids_map: StateMap[ @@ -2149,14 +2148,15 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS # membership (i.e. the room shouldn't disappear if your using the # `is_encrypted` filter and you leave). if membership in (Membership.LEAVE, Membership.BAN) and is_outlier: - invite_or_knock_event_id, invite_or_knock_membership = ( - await self.db_pool.runInteraction( - "sliding_sync_membership_snapshots_bg_update._find_previous_membership", - _find_previous_membership_txn, - room_id, - user_id, - membership_event_id, - ) + ( + invite_or_knock_event_id, + invite_or_knock_membership, + ) = await self.db_pool.runInteraction( + "sliding_sync_membership_snapshots_bg_update._find_previous_membership", + _find_previous_membership_txn, + room_id, + user_id, + membership_event_id, ) # Pull from the stripped state on the invite/knock event @@ -2484,9 +2484,7 @@ def _resolve_stale_data_in_sliding_sync_joined_rooms_table( "progress_json": "{}", }, ) - depends_on = ( - _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE - ) + depends_on = _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE # Now kick-off the background update to catch-up with what we missed while Synapse # was downgraded. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 1d83390827..b188f32927 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1665,7 +1665,7 @@ class EventsWorkerStore(SQLBaseStore): txn.database_engine, "e.event_id", event_ids ) txn.execute(sql + clause, args) - found_events = {eid for eid, in txn} + found_events = {eid for (eid,) in txn} # ... and then we can update the results for each key return {eid: (eid in found_events) for eid in event_ids} @@ -1864,9 +1864,9 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (-last_id, -current_id, instance_name, limit)) - new_event_updates: List[Tuple[int, Tuple[str, str, str, str, str, str]]] = ( - [] - ) + new_event_updates: List[ + Tuple[int, Tuple[str, str, str, str, str, str]] + ] = [] row: Tuple[int, str, str, str, str, str, str] # Type safety: iterating over `txn` yields `Tuple`, i.e. # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index fc4c286595..08244153a3 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -201,7 +201,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): txn.execute_batch( "INSERT INTO event_backward_extremities (room_id, event_id)" " VALUES (?, ?)", - [(room_id, event_id) for event_id, in new_backwards_extrems], + [(room_id, event_id) for (event_id,) in new_backwards_extrems], ) logger.info("[purge] finding state groups referenced by deleted events") @@ -215,7 +215,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): """ ) - referenced_state_groups = {sg for sg, in txn} + referenced_state_groups = {sg for (sg,) in txn} logger.info( "[purge] found %i referenced state groups", len(referenced_state_groups) ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index bf10743574..9964331510 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -762,7 +762,7 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, args) - return [room_id for room_id, in txn] + return [room_id for (room_id,) in txn] results: List[str] = [] for batch in batch_iter(room_ids, 1000): @@ -1030,9 +1030,7 @@ class ReceiptsWorkerStore(SQLBaseStore): SELECT event_id WHERE room_id = ? AND stream_ordering IN ( SELECT max(stream_ordering) WHERE %s ) - """ % ( - clause, - ) + """ % (clause,) txn.execute(sql, [room_id] + list(args)) rows = txn.fetchall() diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index df7f8a43b7..d7cbe33411 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1250,9 +1250,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): SELECT address, session_id, medium, client_secret, last_send_attempt, validated_at FROM threepid_validation_session WHERE %s - """ % ( - " AND ".join("%s = ?" % k for k in keyvalues.keys()), - ) + """ % (" AND ".join("%s = ?" % k for k in keyvalues.keys()),) if validated is not None: sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 80a4bf95f2..68b0806041 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1608,9 +1608,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): FROM event_reports AS er JOIN room_stats_state ON room_stats_state.room_id = er.room_id {} - """.format( - where_clause - ) + """.format(where_clause) txn.execute(sql, args) count = cast(Tuple[int], txn.fetchone())[0] diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 57b9b95c28..3d834b4bf1 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -232,9 +232,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): AND m.room_id = c.room_id AND m.user_id = c.state_key WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? AND %s - """ % ( - clause, - ) + """ % (clause,) txn.execute(sql, (room_id, Membership.JOIN, *ids)) return {r[0]: ProfileInfo(display_name=r[1], avatar_url=r[2]) for r in txn} @@ -531,9 +529,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): WHERE user_id = ? AND %s - """ % ( - clause, - ) + """ % (clause,) txn.execute(sql, (user_id, *args)) results = [ @@ -813,7 +809,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): """ txn.execute(sql, (user_id, *args)) - return {u: True for u, in txn} + return {u: True for (u,) in txn} to_return = {} for batch_user_ids in batch_iter(other_user_ids, 1000): @@ -1031,7 +1027,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): AND room_id = ? """ txn.execute(sql, (room_id,)) - return {d for d, in txn} + return {d for (d,) in txn} return await self.db_pool.runInteraction( "get_current_hosts_in_room", get_current_hosts_in_room_txn @@ -1099,7 +1095,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): """ txn.execute(sql, (room_id,)) # `server_domain` will be `NULL` for malformed MXIDs with no colons. - return tuple(d for d, in txn if d is not None) + return tuple(d for (d,) in txn if d is not None) return await self.db_pool.runInteraction( "get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn @@ -1316,9 +1312,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): room_id = ? AND membership = ? AND NOT (%s) LIMIT 1 - """ % ( - clause, - ) + """ % (clause,) def _is_local_host_in_room_ignoring_users_txn( txn: LoggingTransaction, @@ -1464,10 +1458,12 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): self, progress: JsonDict, batch_size: int ) -> int: target_min_stream_id = progress.get( - "target_min_stream_id_inclusive", self._min_stream_order_on_start # type: ignore[attr-defined] + "target_min_stream_id_inclusive", + self._min_stream_order_on_start, # type: ignore[attr-defined] ) max_stream_id = progress.get( - "max_stream_id_exclusive", self._stream_order_on_start + 1 # type: ignore[attr-defined] + "max_stream_id_exclusive", + self._stream_order_on_start + 1, # type: ignore[attr-defined] ) def add_membership_profile_txn(txn: LoggingTransaction) -> int: diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 20fcfd3122..b436275f3f 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -177,9 +177,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): AND (%s) ORDER BY stream_ordering DESC LIMIT ? - """ % ( - " OR ".join("type = '%s'" % (t,) for t in TYPES), - ) + """ % (" OR ".join("type = '%s'" % (t,) for t in TYPES),) txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 62bc4600fb..c5caaf56b0 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -535,7 +535,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): desc="check_if_events_in_current_state", ) - return frozenset(event_id for event_id, in rows) + return frozenset(event_id for (event_id,) in rows) # FIXME: how should this be cached? @cancellable diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index e9f6a918c7..79c49e7fd9 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -161,7 +161,7 @@ class StatsStore(StateDeltasStore): LIMIT ? """ txn.execute(sql, (last_user_id, batch_size)) - return [r for r, in txn] + return [r for (r,) in txn] users_to_work_on = await self.db_pool.runInteraction( "_populate_stats_process_users", _get_next_batch @@ -207,7 +207,7 @@ class StatsStore(StateDeltasStore): LIMIT ? """ txn.execute(sql, (last_room_id, batch_size)) - return [r for r, in txn] + return [r for (r,) in txn] rooms_to_work_on = await self.db_pool.runInteraction( "populate_stats_rooms_get_batch", _get_next_batch @@ -751,9 +751,7 @@ class StatsStore(StateDeltasStore): LEFT JOIN profiles AS p ON lmr.user_id = p.full_user_id {} GROUP BY lmr.user_id, displayname - """.format( - where_clause - ) + """.format(where_clause) # SQLite does not support SELECT COUNT(*) OVER() sql = """ diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 1a59e0b5a8..68d4168621 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -21,7 +21,7 @@ # # -""" This module is responsible for getting events from the DB for pagination +"""This module is responsible for getting events from the DB for pagination and event streaming. The order it returns events in depend on whether we are streaming forwards or @@ -1122,9 +1122,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): AND e.stream_ordering > ? AND e.stream_ordering <= ? %s ORDER BY e.stream_ordering ASC - """ % ( - ignore_room_clause, - ) + """ % (ignore_room_clause,) txn.execute(sql, args) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 6e18f714d7..51cffb0986 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -224,9 +224,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): SELECT room_id, events FROM %s ORDER BY events DESC LIMIT 250 - """ % ( - TEMP_TABLE + "_rooms", - ) + """ % (TEMP_TABLE + "_rooms",) txn.execute(sql) rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall()) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index d4ac74c1ee..aea71b8fcc 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -767,7 +767,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): remaining_state_groups = { state_group - for state_group, in rows + for (state_group,) in rows if state_group not in state_groups_to_delete } diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index aaffe5ecc9..bf087702ea 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -607,7 +607,7 @@ def _apply_module_schema_files( "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,), ) - applied_deltas = {d for d, in cur} + applied_deltas = {d for (d,) in cur} for name, stream in names_and_streams: if name in applied_deltas: continue @@ -710,7 +710,7 @@ def _get_or_create_schema_state( "SELECT file FROM applied_schema_deltas WHERE version >= ?", (current_version,), ) - applied_deltas = tuple(d for d, in txn) + applied_deltas = tuple(d for (d,) in txn) return _SchemaState( current_version=current_version, diff --git a/synapse/storage/schema/main/delta/56/unique_user_filter_index.py b/synapse/storage/schema/main/delta/56/unique_user_filter_index.py index 2461f87d77..b7535dae14 100644 --- a/synapse/storage/schema/main/delta/56/unique_user_filter_index.py +++ b/synapse/storage/schema/main/delta/56/unique_user_filter_index.py @@ -41,8 +41,6 @@ def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> (user_id, filter_id); DROP TABLE user_filters; ALTER TABLE user_filters_migration RENAME TO user_filters; - """ % ( - select_clause, - ) + """ % (select_clause,) execute_statements_from_stream(cur, StringIO(sql)) diff --git a/synapse/storage/schema/main/delta/61/03recreate_min_depth.py b/synapse/storage/schema/main/delta/61/03recreate_min_depth.py index 5d3578eaf4..a847ef4147 100644 --- a/synapse/storage/schema/main/delta/61/03recreate_min_depth.py +++ b/synapse/storage/schema/main/delta/61/03recreate_min_depth.py @@ -23,6 +23,7 @@ This migration handles the process of changing the type of `room_depth.min_depth` to a BIGINT. """ + from synapse.storage.database import LoggingTransaction from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine diff --git a/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py index b4d4b6536b..9ac3d1d31f 100644 --- a/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py +++ b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py @@ -25,6 +25,7 @@ This migration adds triggers to the partial_state_events tables to enforce uniqu Triggers cannot be expressed in .sql files, so we have to use a separate file. """ + from synapse.storage.database import LoggingTransaction from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine diff --git a/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py b/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py index 93543fca7c..be80a6747d 100644 --- a/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py +++ b/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py @@ -26,6 +26,7 @@ for its completion can be removed. Note the background job must still remain defined in the database class. """ + from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import LoggingTransaction from synapse.storage.engines import BaseDatabaseEngine diff --git a/synapse/storage/schema/main/delta/74/04_membership_tables_event_stream_ordering_triggers.py b/synapse/storage/schema/main/delta/74/04_membership_tables_event_stream_ordering_triggers.py index 6609ef0dac..a847a93494 100644 --- a/synapse/storage/schema/main/delta/74/04_membership_tables_event_stream_ordering_triggers.py +++ b/synapse/storage/schema/main/delta/74/04_membership_tables_event_stream_ordering_triggers.py @@ -24,6 +24,7 @@ This migration adds triggers to the room membership tables to enforce consistency. Triggers cannot be expressed in .sql files, so we have to use a separate file. """ + from synapse.storage.database import LoggingTransaction from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine diff --git a/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py b/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py index ad9c394162..1c823a3aa1 100644 --- a/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py +++ b/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py @@ -23,6 +23,7 @@ """ This migration adds foreign key constraint to `event_forward_extremities` table. """ + from synapse.storage.background_updates import ( ForeignKeyConstraint, run_validate_constraint_and_delete_rows_schema_delta, diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 5259550f1c..26783c5622 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -1308,7 +1308,7 @@ class DeviceListUpdates: def get_verify_key_from_cross_signing_key( - key_info: Mapping[str, Any] + key_info: Mapping[str, Any], ) -> Tuple[str, VerifyKey]: """Get the key ID and signedjson verify key from a cross-signing key dict diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py index 93b537ab7b..9f6fb087c1 100644 --- a/synapse/types/rest/client/__init__.py +++ b/synapse/types/rest/client/__init__.py @@ -268,7 +268,9 @@ class SlidingSyncBody(RequestBodyModel): if TYPE_CHECKING: ranges: Optional[List[Tuple[int, int]]] = None else: - ranges: Optional[List[Tuple[conint(ge=0, strict=True), conint(ge=0, strict=True)]]] = None # type: ignore[valid-type] + ranges: Optional[ + List[Tuple[conint(ge=0, strict=True), conint(ge=0, strict=True)]] + ] = None # type: ignore[valid-type] slow_get_all_rooms: Optional[StrictBool] = False filters: Optional[Filters] = None @@ -388,7 +390,9 @@ class SlidingSyncBody(RequestBodyModel): if TYPE_CHECKING: lists: Optional[Dict[str, SlidingSyncList]] = None else: - lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None # type: ignore[valid-type] + lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = ( + None # type: ignore[valid-type] + ) room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None extensions: Optional[Extensions] = None diff --git a/synapse/types/state.py b/synapse/types/state.py index c958a95701..1141c4b5c1 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py @@ -503,13 +503,19 @@ class StateFilter: # - if so, which event types are excluded? ('excludes') # - which entire event types to include ('wildcards') # - which concrete state keys to include ('concrete state keys') - (self_all, self_excludes), ( - self_wildcards, - self_concrete_keys, + ( + (self_all, self_excludes), + ( + self_wildcards, + self_concrete_keys, + ), ) = self._decompose_into_four_parts() - (other_all, other_excludes), ( - other_wildcards, - other_concrete_keys, + ( + (other_all, other_excludes), + ( + other_wildcards, + other_concrete_keys, + ), ) = other._decompose_into_four_parts() # Start with an estimate of the difference based on self diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py index e9a5fff211..87f801c0cf 100644 --- a/synapse/util/linked_list.py +++ b/synapse/util/linked_list.py @@ -19,8 +19,7 @@ # # -"""A circular doubly linked list implementation. -""" +"""A circular doubly linked list implementation.""" import threading from typing import Generic, Optional, Type, TypeVar diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 517e79ce5f..020618598c 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -110,7 +110,7 @@ def measure_func( """ def wrapper( - func: Callable[Concatenate[HasClock, P], Awaitable[R]] + func: Callable[Concatenate[HasClock, P], Awaitable[R]], ) -> Callable[P, Awaitable[R]]: block_name = func.__name__ if name is None else name diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 46dad32156..56bdf451da 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -50,7 +50,7 @@ def do_patch() -> None: return def new_inline_callbacks( - f: Callable[P, Generator["Deferred[object]", object, T]] + f: Callable[P, Generator["Deferred[object]", object, T]], ) -> Callable[P, "Deferred[T]"]: @functools.wraps(f) def wrapped(*args: P.args, **kwargs: P.kwargs) -> "Deferred[T]": diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 8ead72bb7a..3f067b792c 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -103,7 +103,7 @@ _rate_limiter_instances_lock = threading.Lock() def _get_counts_from_rate_limiter_instance( - count_func: Callable[["FederationRateLimiter"], int] + count_func: Callable[["FederationRateLimiter"], int], ) -> Mapping[Tuple[str, ...], int]: """Returns a count of something (slept/rejected hosts) by (metrics_name)""" # Cast to a list to prevent it changing while the Prometheus diff --git a/synapse/visibility.py b/synapse/visibility.py index 128413c8aa..3a2782bade 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -135,9 +135,9 @@ async def filter_events_for_client( retention_policies: Dict[str, RetentionPolicy] = {} for room_id in room_ids: - retention_policies[room_id] = ( - await storage.main.get_retention_policy_for_room(room_id) - ) + retention_policies[ + room_id + ] = await storage.main.get_retention_policy_for_room(room_id) def allowed(event: EventBase) -> Optional[EventBase]: state_after_event = event_id_to_state.get(event.event_id) diff --git a/synmark/__main__.py b/synmark/__main__.py index cac57cf111..746261a1ec 100644 --- a/synmark/__main__.py +++ b/synmark/__main__.py @@ -40,7 +40,7 @@ T = TypeVar("T") def make_test( - main: Callable[[ISynapseReactor, int], Coroutine[Any, Any, float]] + main: Callable[[ISynapseReactor, int], Coroutine[Any, Any, float]], ) -> Callable[[int], float]: """ Take a benchmark function and wrap it in a reactor start and stop. diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index a1c7ccdd0b..730b00a9fb 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -150,7 +150,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.assertEqual(1, len(self.txnctrl.recoverers)) # and stored self.assertEqual(0, txn.complete.call_count) # txn not completed self.store.set_appservice_state.assert_called_once_with( - service, ApplicationServiceState.DOWN # service marked as down + service, + ApplicationServiceState.DOWN, # service marked as down ) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 30f8787758..654e6521a2 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -756,7 +756,8 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): def test_event_fields_fail_if_fields_not_str(self) -> None: with self.assertRaises(TypeError): self.serialize( - MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4] # type: ignore[list-item] + MockEvent(room_id="!foo:bar", content={"foo": "bar"}), + ["room_id", 4], # type: ignore[list-item] ) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 9bd97e5d4e..87b9ffc0c6 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -158,7 +158,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): async def get_current_state_event_counts(room_id: str) -> int: return 600 - self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[method-assign] + self.hs.get_datastores().main.get_current_state_event_counts = ( # type: ignore[method-assign] + get_current_state_event_counts + ) d = handler._remote_join( create_requester(u1), diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 08214b0013..1e1ed8e642 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -401,7 +401,10 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): now = self.clock.time_msec() self.get_success( self.hs.get_datastores().main.set_destination_retry_timings( - "zzzerver", now, now, 24 * 60 * 60 * 1000 # retry in 1 day + "zzzerver", + now, + now, + 24 * 60 * 60 * 1000, # retry in 1 day ) ) diff --git a/tests/federation/test_federation_media.py b/tests/federation/test_federation_media.py index 0dcf20f5f5..e66aae499b 100644 --- a/tests/federation/test_federation_media.py +++ b/tests/federation/test_federation_media.py @@ -40,7 +40,6 @@ from tests.test_utils import SMALL_PNG class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase): - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: super().prepare(reactor, clock, hs) self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") @@ -150,7 +149,6 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase): class FederationThumbnailTest(unittest.FederatingHomeserverTestCase): - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: super().prepare(reactor, clock, hs) self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 1b83aea579..5db10fa74c 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -288,13 +288,15 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): } # We also expect an outbound request to /state - self.mock_federation_transport_client.get_room_state.return_value = StateRequestResponse( - # Mimic the other server not knowing about the state at all. - # We want to cause Synapse to throw an error (`Unable to get - # missing prev_event $fake_prev_event`) and fail to backfill - # the pulled event. - auth_events=[], - state=[], + self.mock_federation_transport_client.get_room_state.return_value = ( + StateRequestResponse( + # Mimic the other server not knowing about the state at all. + # We want to cause Synapse to throw an error (`Unable to get + # missing prev_event $fake_prev_event`) and fail to backfill + # the pulled event. + auth_events=[], + state=[], + ) ) pulled_event = make_event_from_dict( diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index cc630d606c..598d6c13cd 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -1107,7 +1107,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): ), ] ], - name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[5] else 'monolith'}", + name_func=lambda testcase_func, + param_num, + params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[5] else 'monolith'}", ) @unittest.override_config({"experimental_features": {"msc3026_enabled": True}}) def test_set_presence_from_syncing_multi_device( @@ -1343,7 +1345,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): ), ] ], - name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[4] else 'monolith'}", + name_func=lambda testcase_func, + param_num, + params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[4] else 'monolith'}", ) @unittest.override_config({"experimental_features": {"msc3026_enabled": True}}) def test_set_presence_from_non_syncing_multi_device( diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index fa55f76916..d7bbc68037 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -843,7 +843,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ) -> List[EventBase]: return list(pdus) - self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment] + self.client._check_sigs_and_hash_for_pulled_events_and_fetch = ( # type: ignore[method-assign] + _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment] + ) prev_events = self.get_success(self.store.get_prev_events_for_room(room_id)) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 8e8621e348..ffcbf4b3ca 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -93,9 +93,7 @@ class SrvResolverTestCase(unittest.TestCase): resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) servers: List[Server] - servers = yield defer.ensureDeferred( - resolver.resolve_service(service_name) - ) # type: ignore[assignment] + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) # type: ignore[assignment] dns_client_mock.lookupService.assert_called_once_with(service_name) @@ -122,9 +120,7 @@ class SrvResolverTestCase(unittest.TestCase): ) servers: List[Server] - servers = yield defer.ensureDeferred( - resolver.resolve_service(service_name) - ) # type: ignore[assignment] + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) # type: ignore[assignment] self.assertFalse(dns_client_mock.lookupService.called) @@ -157,9 +153,7 @@ class SrvResolverTestCase(unittest.TestCase): resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) servers: List[Server] - servers = yield defer.ensureDeferred( - resolver.resolve_service(service_name) - ) # type: ignore[assignment] + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) # type: ignore[assignment] self.assertEqual(len(servers), 0) self.assertEqual(len(cache), 0) diff --git a/tests/http/test_client.py b/tests/http/test_client.py index f2abec190b..ac6470ebbd 100644 --- a/tests/http/test_client.py +++ b/tests/http/test_client.py @@ -207,7 +207,9 @@ class ReadMultipartResponseTests(TestCase): class ReadBodyWithMaxSizeTests(TestCase): - def _build_response(self, length: Union[int, str] = UNKNOWN_LENGTH) -> Tuple[ + def _build_response( + self, length: Union[int, str] = UNKNOWN_LENGTH + ) -> Tuple[ BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol, diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py index 6827412373..6588695e37 100644 --- a/tests/http/test_matrixfederationclient.py +++ b/tests/http/test_matrixfederationclient.py @@ -895,21 +895,23 @@ class FederationClientProxyTests(BaseMultiWorkerStreamTestCase): ) # Fake `remoteserv:8008` responding to requests - mock_agent_on_federation_sender.request.side_effect = lambda *args, **kwargs: defer.succeed( - FakeResponse( - code=200, - body=b'{"foo": "bar"}', - headers=Headers( - { - "Content-Type": ["application/json"], - "Connection": ["close, X-Foo, X-Bar"], - # Should be removed because it's defined in the `Connection` header - "X-Foo": ["foo"], - "X-Bar": ["bar"], - # Should be removed because it's a hop-by-hop header - "Proxy-Authorization": "abcdef", - } - ), + mock_agent_on_federation_sender.request.side_effect = ( + lambda *args, **kwargs: defer.succeed( + FakeResponse( + code=200, + body=b'{"foo": "bar"}', + headers=Headers( + { + "Content-Type": ["application/json"], + "Connection": ["close, X-Foo, X-Bar"], + # Should be removed because it's defined in the `Connection` header + "X-Foo": ["foo"], + "X-Bar": ["bar"], + # Should be removed because it's a hop-by-hop header + "Proxy-Authorization": "abcdef", + } + ), + ) ) ) diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py index 18af2735fe..db39ecf244 100644 --- a/tests/http/test_servlet.py +++ b/tests/http/test_servlet.py @@ -76,7 +76,7 @@ class TestServletUtils(unittest.TestCase): # Invalid UTF-8. with self.assertRaises(SynapseError): - parse_json_value_from_request(make_request(b"\xFF\x00")) + parse_json_value_from_request(make_request(b"\xff\x00")) # Invalid JSON. with self.assertRaises(SynapseError): diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index e55001fb40..e50ff5fa78 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -261,7 +261,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): """A mock for MatrixFederationHttpClient.get_file.""" def write_to( - r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]] + r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]], ) -> Tuple[int, Dict[bytes, List[bytes]]]: data, response = r output_stream.write(data) diff --git a/tests/module_api/test_account_data_manager.py b/tests/module_api/test_account_data_manager.py index fd87eaffd0..1a1d5609b2 100644 --- a/tests/module_api/test_account_data_manager.py +++ b/tests/module_api/test_account_data_manager.py @@ -164,6 +164,8 @@ class ModuleApiTestCase(HomeserverTestCase): # noinspection PyTypeChecker self.get_success_or_raise( self._module_api.account_data_manager.put_global( - self.user_id, "test.data", 42 # type: ignore[arg-type] + self.user_id, + "test.data", + 42, # type: ignore[arg-type] ) ) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index e0aab1c046..4fafb71897 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -44,6 +44,7 @@ from tests.unittest import HomeserverTestCase @attr.s(auto_attribs=True) class _User: "Helper wrapper for user ID and access token" + id: str token: str diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index 2a1e42bbc8..150caeeee2 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -531,9 +531,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): # simulate a change in server config after a server restart. new_display_name = "new display name" - self.server_notices_manager._config.servernotices.server_notices_mxid_display_name = ( - new_display_name - ) + self.server_notices_manager._config.servernotices.server_notices_mxid_display_name = new_display_name self.server_notices_manager.get_or_create_notice_room_for_user.cache.invalidate_all() self.make_request( @@ -577,9 +575,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): # simulate a change in server config after a server restart. new_avatar_url = "test/new-url" - self.server_notices_manager._config.servernotices.server_notices_mxid_avatar_url = ( - new_avatar_url - ) + self.server_notices_manager._config.servernotices.server_notices_mxid_avatar_url = new_avatar_url self.server_notices_manager.get_or_create_notice_room_for_user.cache.invalidate_all() self.make_request( @@ -692,9 +688,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): # simulate a change in server config after a server restart. new_avatar_url = "test/new-url" - self.server_notices_manager._config.servernotices.server_notices_room_avatar_url = ( - new_avatar_url - ) + self.server_notices_manager._config.servernotices.server_notices_room_avatar_url = new_avatar_url self.server_notices_manager.get_or_create_notice_room_for_user.cache.invalidate_all() self.make_request( diff --git a/tests/rest/client/sliding_sync/test_connection_tracking.py b/tests/rest/client/sliding_sync/test_connection_tracking.py index 436bd4466c..5b819103c2 100644 --- a/tests/rest/client/sliding_sync/test_connection_tracking.py +++ b/tests/rest/client/sliding_sync/test_connection_tracking.py @@ -38,7 +38,9 @@ logger = logging.getLogger(__name__) (True,), (False,), ], - class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", ) class SlidingSyncConnectionTrackingTestCase(SlidingSyncBase): """ diff --git a/tests/rest/client/sliding_sync/test_extension_account_data.py b/tests/rest/client/sliding_sync/test_extension_account_data.py index 65a6adf4af..6cc883a4be 100644 --- a/tests/rest/client/sliding_sync/test_extension_account_data.py +++ b/tests/rest/client/sliding_sync/test_extension_account_data.py @@ -40,7 +40,9 @@ logger = logging.getLogger(__name__) (True,), (False,), ], - class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", ) class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase): """Tests for the account_data sliding sync extension""" diff --git a/tests/rest/client/sliding_sync/test_extension_e2ee.py b/tests/rest/client/sliding_sync/test_extension_e2ee.py index 2ff6687796..7ce6592d8f 100644 --- a/tests/rest/client/sliding_sync/test_extension_e2ee.py +++ b/tests/rest/client/sliding_sync/test_extension_e2ee.py @@ -39,7 +39,9 @@ logger = logging.getLogger(__name__) (True,), (False,), ], - class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", ) class SlidingSyncE2eeExtensionTestCase(SlidingSyncBase): """Tests for the e2ee sliding sync extension""" diff --git a/tests/rest/client/sliding_sync/test_extension_receipts.py b/tests/rest/client/sliding_sync/test_extension_receipts.py index 90b035dd75..6e7700b533 100644 --- a/tests/rest/client/sliding_sync/test_extension_receipts.py +++ b/tests/rest/client/sliding_sync/test_extension_receipts.py @@ -40,7 +40,9 @@ logger = logging.getLogger(__name__) (True,), (False,), ], - class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", ) class SlidingSyncReceiptsExtensionTestCase(SlidingSyncBase): """Tests for the receipts sliding sync extension""" diff --git a/tests/rest/client/sliding_sync/test_extension_to_device.py b/tests/rest/client/sliding_sync/test_extension_to_device.py index 5ba2443089..790abb739d 100644 --- a/tests/rest/client/sliding_sync/test_extension_to_device.py +++ b/tests/rest/client/sliding_sync/test_extension_to_device.py @@ -40,7 +40,9 @@ logger = logging.getLogger(__name__) (True,), (False,), ], - class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", ) class SlidingSyncToDeviceExtensionTestCase(SlidingSyncBase): """Tests for the to-device sliding sync extension""" diff --git a/tests/rest/client/sliding_sync/test_extension_typing.py b/tests/rest/client/sliding_sync/test_extension_typing.py index 0a0f5aff1a..f87c3c8b17 100644 --- a/tests/rest/client/sliding_sync/test_extension_typing.py +++ b/tests/rest/client/sliding_sync/test_extension_typing.py @@ -40,7 +40,9 @@ logger = logging.getLogger(__name__) (True,), (False,), ], - class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", ) class SlidingSyncTypingExtensionTestCase(SlidingSyncBase): """Tests for the typing notification sliding sync extension""" diff --git a/tests/rest/client/sliding_sync/test_extensions.py b/tests/rest/client/sliding_sync/test_extensions.py index 32478467aa..30230e5c4b 100644 --- a/tests/rest/client/sliding_sync/test_extensions.py +++ b/tests/rest/client/sliding_sync/test_extensions.py @@ -40,7 +40,9 @@ logger = logging.getLogger(__name__) (True,), (False,), ], - class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", ) class SlidingSyncExtensionsTestCase(SlidingSyncBase): """ diff --git a/tests/rest/client/sliding_sync/test_room_subscriptions.py b/tests/rest/client/sliding_sync/test_room_subscriptions.py index e81d251839..285fdaaf78 100644 --- a/tests/rest/client/sliding_sync/test_room_subscriptions.py +++ b/tests/rest/client/sliding_sync/test_room_subscriptions.py @@ -39,7 +39,9 @@ logger = logging.getLogger(__name__) (True,), (False,), ], - class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", ) class SlidingSyncRoomSubscriptionsTestCase(SlidingSyncBase): """ diff --git a/tests/rest/client/sliding_sync/test_rooms_invites.py b/tests/rest/client/sliding_sync/test_rooms_invites.py index f6f45c2500..882762ca29 100644 --- a/tests/rest/client/sliding_sync/test_rooms_invites.py +++ b/tests/rest/client/sliding_sync/test_rooms_invites.py @@ -39,7 +39,9 @@ logger = logging.getLogger(__name__) (True,), (False,), ], - class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", ) class SlidingSyncRoomsInvitesTestCase(SlidingSyncBase): """ diff --git a/tests/rest/client/sliding_sync/test_rooms_meta.py b/tests/rest/client/sliding_sync/test_rooms_meta.py index 71542923da..4ed49040c1 100644 --- a/tests/rest/client/sliding_sync/test_rooms_meta.py +++ b/tests/rest/client/sliding_sync/test_rooms_meta.py @@ -40,7 +40,9 @@ logger = logging.getLogger(__name__) (True,), (False,), ], - class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", ) class SlidingSyncRoomsMetaTestCase(SlidingSyncBase): """ diff --git a/tests/rest/client/sliding_sync/test_rooms_required_state.py b/tests/rest/client/sliding_sync/test_rooms_required_state.py index 436ae684da..91ac6c5a0e 100644 --- a/tests/rest/client/sliding_sync/test_rooms_required_state.py +++ b/tests/rest/client/sliding_sync/test_rooms_required_state.py @@ -40,7 +40,9 @@ logger = logging.getLogger(__name__) (True,), (False,), ], - class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", ) class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase): """ diff --git a/tests/rest/client/sliding_sync/test_rooms_timeline.py b/tests/rest/client/sliding_sync/test_rooms_timeline.py index e56fb58012..0e027ff39d 100644 --- a/tests/rest/client/sliding_sync/test_rooms_timeline.py +++ b/tests/rest/client/sliding_sync/test_rooms_timeline.py @@ -40,7 +40,9 @@ logger = logging.getLogger(__name__) (True,), (False,), ], - class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", ) class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase): """ diff --git a/tests/rest/client/sliding_sync/test_sliding_sync.py b/tests/rest/client/sliding_sync/test_sliding_sync.py index 1dcc15b082..930cb5ef45 100644 --- a/tests/rest/client/sliding_sync/test_sliding_sync.py +++ b/tests/rest/client/sliding_sync/test_sliding_sync.py @@ -232,7 +232,9 @@ class SlidingSyncBase(unittest.HomeserverTestCase): (True,), (False,), ], - class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", ) class SlidingSyncTestCase(SlidingSyncBase): """ diff --git a/tests/rest/client/test_auth_issuer.py b/tests/rest/client/test_auth_issuer.py index 299475a35c..d6f334a7ab 100644 --- a/tests/rest/client/test_auth_issuer.py +++ b/tests/rest/client/test_auth_issuer.py @@ -63,7 +63,9 @@ class AuthIssuerTestCase(HomeserverTestCase): self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.json_body, {"issuer": ISSUER}) - req_mock.assert_called_with("https://account.example.com/.well-known/openid-configuration") + req_mock.assert_called_with( + "https://account.example.com/.well-known/openid-configuration" + ) req_mock.reset_mock() # Second call it should use the cached value diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index 06f1c1b234..039144fdbe 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -19,7 +19,7 @@ # # -""" Tests REST events for /events paths.""" +"""Tests REST events for /events paths.""" from unittest.mock import Mock diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index 30b6d31d0a..42014e257e 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -1957,7 +1957,7 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase): """A mock for MatrixFederationHttpClient.federation_get_file.""" def write_to( - r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]] + r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]], ) -> Tuple[int, Dict[bytes, List[bytes]], bytes]: data, response = r output_stream.write(data) @@ -1991,7 +1991,7 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase): """A mock for MatrixFederationHttpClient.get_file.""" def write_to( - r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]] + r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]], ) -> Tuple[int, Dict[bytes, List[bytes]]]: data, response = r output_stream.write(data) diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index f98f3f77aa..a92713d220 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -20,6 +20,7 @@ # """Tests REST events for /profile paths.""" + import urllib.parse from http import HTTPStatus from typing import Any, Dict, Optional diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 694f143eff..c091f403cc 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -1049,9 +1049,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # Check that the HTML we're getting is the one we expect when using an # invalid/unknown token. - expected_html = ( - self.hs.config.account_validity.account_validity_invalid_token_template.render() - ) + expected_html = self.hs.config.account_validity.account_validity_invalid_token_template.render() self.assertEqual( channel.result["body"], expected_html.encode("utf8"), channel.result ) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 9614cdd66a..a1c284726a 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -330,22 +330,24 @@ class RestHelper: data, ) - assert ( - channel.code == expect_code - ), "Expected: %d, got: %d, PUT %s -> resp: %r" % ( - expect_code, - channel.code, - path, - channel.result["body"], + assert channel.code == expect_code, ( + "Expected: %d, got: %d, PUT %s -> resp: %r" + % ( + expect_code, + channel.code, + path, + channel.result["body"], + ) ) if expect_errcode: - assert ( - str(channel.json_body["errcode"]) == expect_errcode - ), "Expected: %r, got: %r, resp: %r" % ( - expect_errcode, - channel.json_body["errcode"], - channel.result["body"], + assert str(channel.json_body["errcode"]) == expect_errcode, ( + "Expected: %r, got: %r, resp: %r" + % ( + expect_errcode, + channel.json_body["errcode"], + channel.result["body"], + ) ) if expect_additional_fields is not None: @@ -354,13 +356,14 @@ class RestHelper: expect_key, channel.json_body, ) - assert ( - channel.json_body[expect_key] == expect_value - ), "Expected: %s at %s, got: %s, resp: %s" % ( - expect_value, - expect_key, - channel.json_body[expect_key], - channel.json_body, + assert channel.json_body[expect_key] == expect_value, ( + "Expected: %s at %s, got: %s, resp: %s" + % ( + expect_value, + expect_key, + channel.json_body[expect_key], + channel.json_body, + ) ) self.auth_user_id = temp_id diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index ac992766e8..c73717f014 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -124,7 +124,12 @@ class WellKnownTests(unittest.HomeserverTestCase): ) def test_client_well_known_msc3861_oauth_delegation(self) -> None: # Patch the HTTP client to return the issuer metadata - req_mock = AsyncMock(return_value={"issuer": "https://issuer", "account_management_uri": "https://my-account.issuer"}) + req_mock = AsyncMock( + return_value={ + "issuer": "https://issuer", + "account_management_uri": "https://my-account.issuer", + } + ) self.hs.get_proxied_http_client().get_json = req_mock # type: ignore[method-assign] for _ in range(2): @@ -145,4 +150,6 @@ class WellKnownTests(unittest.HomeserverTestCase): ) # It should have been called exactly once, because it gets cached - req_mock.assert_called_once_with("https://issuer/.well-known/openid-configuration") + req_mock.assert_called_once_with( + "https://issuer/.well-known/openid-configuration" + ) diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 6d1ae4c8d7..f12402f5f2 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -292,12 +292,14 @@ class EventAuthTestCase(unittest.TestCase): ] # pleb should not be able to send state - self.assertRaises( - AuthError, - event_auth.check_state_dependent_auth_rules, - _random_state_event(RoomVersions.V1, pleb), - auth_events, - ), + ( + self.assertRaises( + AuthError, + event_auth.check_state_dependent_auth_rules, + _random_state_event(RoomVersions.V1, pleb), + auth_events, + ), + ) # king should be able to send state event_auth.check_state_dependent_auth_rules( diff --git a/tests/test_federation.py b/tests/test_federation.py index 4e9adc0625..94b0fa9856 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -101,7 +101,9 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) -> List[EventBase]: return list(pdus) - self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment] + self.client._check_sigs_and_hash_for_pulled_events_and_fetch = ( # type: ignore[method-assign] + _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment] + ) # Send the join, it should return None (which is not an error) self.assertEqual( diff --git a/tests/test_types.py b/tests/test_types.py index 00adc65a5a..0c08bc8ecc 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -145,7 +145,9 @@ class MapUsernameTestCase(unittest.TestCase): (MultiWriterStreamToken,), (RoomStreamToken,), ], - class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}", + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}", ) class MultiWriterTokenTestCase(unittest.HomeserverTestCase): """Tests for the different types of multi writer tokens.""" diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 4ab42a02b9..4d7adf7204 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -22,6 +22,7 @@ """ Utilities for running the unit tests """ + import json import sys import warnings diff --git a/tests/unittest.py b/tests/unittest.py index 2532fa49fb..614e805abd 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -457,7 +457,9 @@ class HomeserverTestCase(TestCase): # Type ignore: mypy doesn't like us assigning to methods. self.hs.get_auth().get_user_by_req = get_requester # type: ignore[method-assign] self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[method-assign] - self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[method-assign] + self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[method-assign] + return_value=token + ) if self.needs_threadpool: self.reactor.threadpool = ThreadPool() # type: ignore[assignment]