mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-25 02:55:46 +03:00
Require type hints in the handlers module. (#10831)
Adds missing type hints to methods in the synapse.handlers module and requires all methods to have type hints there. This also removes the unused construct_auth_difference method from the FederationHandler.
This commit is contained in:
parent
437961744c
commit
b3590614da
35 changed files with 194 additions and 295 deletions
1
changelog.d/10831.misc
Normal file
1
changelog.d/10831.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add missing type hints to handlers.
|
3
mypy.ini
3
mypy.ini
|
@ -91,6 +91,9 @@ files =
|
|||
tests/util/test_itertools.py,
|
||||
tests/util/test_stream_change_cache.py
|
||||
|
||||
[mypy-synapse.handlers.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.rest.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Tuple, Type
|
||||
|
||||
from synapse.util.module_loader import load_module
|
||||
|
||||
|
@ -25,7 +25,7 @@ class PasswordAuthProviderConfig(Config):
|
|||
section = "authproviders"
|
||||
|
||||
def read_config(self, config, **kwargs):
|
||||
self.password_providers: List[Any] = []
|
||||
self.password_providers: List[Tuple[Type, Any]] = []
|
||||
providers = []
|
||||
|
||||
# We want to be backwards compatible with the old `ldap_config`
|
||||
|
|
|
@ -16,6 +16,7 @@ import logging
|
|||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.types import Requester
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -63,16 +64,21 @@ class BaseHandler:
|
|||
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
|
||||
async def ratelimit(self, requester, update=True, is_admin_redaction=False):
|
||||
async def ratelimit(
|
||||
self,
|
||||
requester: Requester,
|
||||
update: bool = True,
|
||||
is_admin_redaction: bool = False,
|
||||
) -> None:
|
||||
"""Ratelimits requests.
|
||||
|
||||
Args:
|
||||
requester (Requester)
|
||||
update (bool): Whether to record that a request is being processed.
|
||||
requester
|
||||
update: Whether to record that a request is being processed.
|
||||
Set to False when doing multiple checks for one request (e.g.
|
||||
to check up front if we would reject the request), and set to
|
||||
True for the last call for a given request.
|
||||
is_admin_redaction (bool): Whether this is a room admin/moderator
|
||||
is_admin_redaction: Whether this is a room admin/moderator
|
||||
redacting an event. If so then we may apply different
|
||||
ratelimits depending on config.
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any, List, Tuple
|
||||
|
||||
from synapse.replication.http.account_data import (
|
||||
ReplicationAddTagRestServlet,
|
||||
|
@ -171,7 +171,7 @@ class AccountDataEventSource:
|
|||
return self.store.get_max_account_data_stream_id()
|
||||
|
||||
async def get_new_events(
|
||||
self, user: UserID, from_key: int, **kwargs
|
||||
self, user: UserID, from_key: int, **kwargs: Any
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
user_id = user.to_string()
|
||||
last_stream_id = from_key
|
||||
|
|
|
@ -99,7 +99,7 @@ class AccountValidityHandler:
|
|||
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
|
||||
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
|
||||
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Register callbacks from module for each hook."""
|
||||
if is_user_expired is not None:
|
||||
self._is_user_expired_callbacks.append(is_user_expired)
|
||||
|
@ -165,7 +165,7 @@ class AccountValidityHandler:
|
|||
|
||||
return False
|
||||
|
||||
async def on_user_registration(self, user_id: str):
|
||||
async def on_user_registration(self, user_id: str) -> None:
|
||||
"""Tell third-party modules about a user's registration.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
|
@ -58,7 +58,7 @@ class ApplicationServicesHandler:
|
|||
self.current_max = 0
|
||||
self.is_processing = False
|
||||
|
||||
def notify_interested_services(self, max_token: RoomStreamToken):
|
||||
def notify_interested_services(self, max_token: RoomStreamToken) -> None:
|
||||
"""Notifies (pushes) all application services interested in this event.
|
||||
|
||||
Pushing is done asynchronously, so this method won't block for any
|
||||
|
@ -82,7 +82,7 @@ class ApplicationServicesHandler:
|
|||
self._notify_interested_services(max_token)
|
||||
|
||||
@wrap_as_background_process("notify_interested_services")
|
||||
async def _notify_interested_services(self, max_token: RoomStreamToken):
|
||||
async def _notify_interested_services(self, max_token: RoomStreamToken) -> None:
|
||||
with Measure(self.clock, "notify_interested_services"):
|
||||
self.is_processing = True
|
||||
try:
|
||||
|
@ -100,7 +100,7 @@ class ApplicationServicesHandler:
|
|||
for event in events:
|
||||
events_by_room.setdefault(event.room_id, []).append(event)
|
||||
|
||||
async def handle_event(event):
|
||||
async def handle_event(event: EventBase) -> None:
|
||||
# Gather interested services
|
||||
services = await self._get_services_for_event(event)
|
||||
if len(services) == 0:
|
||||
|
@ -116,9 +116,9 @@ class ApplicationServicesHandler:
|
|||
|
||||
if not self.started_scheduler:
|
||||
|
||||
async def start_scheduler():
|
||||
async def start_scheduler() -> None:
|
||||
try:
|
||||
return await self.scheduler.start()
|
||||
await self.scheduler.start()
|
||||
except Exception:
|
||||
logger.error("Application Services Failure")
|
||||
|
||||
|
@ -137,7 +137,7 @@ class ApplicationServicesHandler:
|
|||
"appservice_sender"
|
||||
).observe((now - ts) / 1000)
|
||||
|
||||
async def handle_room_events(events):
|
||||
async def handle_room_events(events: Iterable[EventBase]) -> None:
|
||||
for event in events:
|
||||
await handle_event(event)
|
||||
|
||||
|
@ -184,7 +184,7 @@ class ApplicationServicesHandler:
|
|||
stream_key: str,
|
||||
new_token: Optional[int],
|
||||
users: Optional[Collection[Union[str, UserID]]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""This is called by the notifier in the background
|
||||
when a ephemeral event handled by the homeserver.
|
||||
|
||||
|
@ -226,7 +226,7 @@ class ApplicationServicesHandler:
|
|||
stream_key: str,
|
||||
new_token: Optional[int],
|
||||
users: Collection[Union[str, UserID]],
|
||||
):
|
||||
) -> None:
|
||||
logger.debug("Checking interested services for %s" % (stream_key))
|
||||
with Measure(self.clock, "notify_interested_services_ephemeral"):
|
||||
for service in services:
|
||||
|
|
|
@ -29,6 +29,7 @@ from typing import (
|
|||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
@ -439,7 +440,7 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
return ui_auth_types
|
||||
|
||||
def get_enabled_auth_types(self):
|
||||
def get_enabled_auth_types(self) -> Iterable[str]:
|
||||
"""Return the enabled user-interactive authentication types
|
||||
|
||||
Returns the UI-Auth types which are supported by the homeserver's current
|
||||
|
@ -702,7 +703,7 @@ class AuthHandler(BaseHandler):
|
|||
except StoreError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||
|
||||
async def _expire_old_sessions(self):
|
||||
async def _expire_old_sessions(self) -> None:
|
||||
"""
|
||||
Invalidate any user interactive authentication sessions that have expired.
|
||||
"""
|
||||
|
@ -1352,7 +1353,7 @@ class AuthHandler(BaseHandler):
|
|||
await self.auth.check_auth_blocking(res.user_id)
|
||||
return res
|
||||
|
||||
async def delete_access_token(self, access_token: str):
|
||||
async def delete_access_token(self, access_token: str) -> None:
|
||||
"""Invalidate a single access token
|
||||
|
||||
Args:
|
||||
|
@ -1381,7 +1382,7 @@ class AuthHandler(BaseHandler):
|
|||
user_id: str,
|
||||
except_token_id: Optional[int] = None,
|
||||
device_id: Optional[str] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Invalidate access tokens belonging to a user
|
||||
|
||||
Args:
|
||||
|
@ -1409,7 +1410,7 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
async def add_threepid(
|
||||
self, user_id: str, medium: str, address: str, validated_at: int
|
||||
):
|
||||
) -> None:
|
||||
# check if medium has a valid value
|
||||
if medium not in ["email", "msisdn"]:
|
||||
raise SynapseError(
|
||||
|
@ -1480,7 +1481,7 @@ class AuthHandler(BaseHandler):
|
|||
Hashed password.
|
||||
"""
|
||||
|
||||
def _do_hash():
|
||||
def _do_hash() -> str:
|
||||
# Normalise the Unicode in the password
|
||||
pw = unicodedata.normalize("NFKC", password)
|
||||
|
||||
|
@ -1504,7 +1505,7 @@ class AuthHandler(BaseHandler):
|
|||
Whether self.hash(password) == stored_hash.
|
||||
"""
|
||||
|
||||
def _do_validate_hash(checked_hash: bytes):
|
||||
def _do_validate_hash(checked_hash: bytes) -> bool:
|
||||
# Normalise the Unicode in the password
|
||||
pw = unicodedata.normalize("NFKC", password)
|
||||
|
||||
|
@ -1581,7 +1582,7 @@ class AuthHandler(BaseHandler):
|
|||
client_redirect_url: str,
|
||||
extra_attributes: Optional[JsonDict] = None,
|
||||
new_user: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
"""Having figured out a mxid for this user, complete the HTTP request
|
||||
|
||||
Args:
|
||||
|
@ -1627,7 +1628,7 @@ class AuthHandler(BaseHandler):
|
|||
extra_attributes: Optional[JsonDict] = None,
|
||||
new_user: bool = False,
|
||||
user_profile_data: Optional[ProfileInfo] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
The synchronous portion of complete_sso_login.
|
||||
|
||||
|
@ -1726,7 +1727,7 @@ class AuthHandler(BaseHandler):
|
|||
del self._extra_attributes[user_id]
|
||||
|
||||
@staticmethod
|
||||
def add_query_param_to_url(url: str, param_name: str, param: Any):
|
||||
def add_query_param_to_url(url: str, param_name: str, param: Any) -> str:
|
||||
url_parts = list(urllib.parse.urlparse(url))
|
||||
query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
|
||||
query.append((param_name, param))
|
||||
|
@ -1734,9 +1735,9 @@ class AuthHandler(BaseHandler):
|
|||
return urllib.parse.urlunparse(url_parts)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class MacaroonGenerator:
|
||||
hs = attr.ib()
|
||||
hs: "HomeServer"
|
||||
|
||||
def generate_guest_access_token(self, user_id: str) -> str:
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
|
@ -1816,7 +1817,9 @@ class PasswordProvider:
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
|
||||
def load(
|
||||
cls, module: Type, config: JsonDict, module_api: ModuleApi
|
||||
) -> "PasswordProvider":
|
||||
try:
|
||||
pp = module(config=config, account_handler=module_api)
|
||||
except Exception as e:
|
||||
|
@ -1824,7 +1827,7 @@ class PasswordProvider:
|
|||
raise
|
||||
return cls(pp, module_api)
|
||||
|
||||
def __init__(self, pp, module_api: ModuleApi):
|
||||
def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
|
||||
self._pp = pp
|
||||
self._module_api = module_api
|
||||
|
||||
|
@ -1838,7 +1841,7 @@ class PasswordProvider:
|
|||
if g:
|
||||
self._supported_login_types.update(g())
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return str(self._pp)
|
||||
|
||||
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
|
||||
|
@ -1876,19 +1879,19 @@ class PasswordProvider:
|
|||
"""
|
||||
# first grandfather in a call to check_password
|
||||
if login_type == LoginType.PASSWORD:
|
||||
g = getattr(self._pp, "check_password", None)
|
||||
if g:
|
||||
check_password = getattr(self._pp, "check_password", None)
|
||||
if check_password:
|
||||
qualified_user_id = self._module_api.get_qualified_user_id(username)
|
||||
is_valid = await self._pp.check_password(
|
||||
is_valid = await check_password(
|
||||
qualified_user_id, login_dict["password"]
|
||||
)
|
||||
if is_valid:
|
||||
return qualified_user_id, None
|
||||
|
||||
g = getattr(self._pp, "check_auth", None)
|
||||
if not g:
|
||||
check_auth = getattr(self._pp, "check_auth", None)
|
||||
if not check_auth:
|
||||
return None
|
||||
result = await g(username, login_type, login_dict)
|
||||
result = await check_auth(username, login_type, login_dict)
|
||||
|
||||
# Check if the return value is a str or a tuple
|
||||
if isinstance(result, str):
|
||||
|
|
|
@ -34,20 +34,20 @@ logger = logging.getLogger(__name__)
|
|||
class CasError(Exception):
|
||||
"""Used to catch errors when validating the CAS ticket."""
|
||||
|
||||
def __init__(self, error, error_description=None):
|
||||
def __init__(self, error: str, error_description: Optional[str] = None):
|
||||
self.error = error
|
||||
self.error_description = error_description
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
if self.error_description:
|
||||
return f"{self.error}: {self.error_description}"
|
||||
return self.error
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class CasResponse:
|
||||
username = attr.ib(type=str)
|
||||
attributes = attr.ib(type=Dict[str, List[Optional[str]]])
|
||||
username: str
|
||||
attributes: Dict[str, List[Optional[str]]]
|
||||
|
||||
|
||||
class CasHandler:
|
||||
|
@ -133,11 +133,9 @@ class CasHandler:
|
|||
body = pde.response
|
||||
except HttpResponseException as e:
|
||||
description = (
|
||||
(
|
||||
'Authorization server responded with a "{status}" error '
|
||||
"while exchanging the authorization code."
|
||||
).format(status=e.code),
|
||||
)
|
||||
'Authorization server responded with a "{status}" error '
|
||||
"while exchanging the authorization code."
|
||||
).format(status=e.code)
|
||||
raise CasError("server_error", description) from e
|
||||
|
||||
return self._parse_cas_response(body)
|
||||
|
|
|
@ -267,7 +267,7 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
|
||||
hs.get_distributor().observe("user_left_room", self.user_left_room)
|
||||
|
||||
def _check_device_name_length(self, name: Optional[str]):
|
||||
def _check_device_name_length(self, name: Optional[str]) -> None:
|
||||
"""
|
||||
Checks whether a device name is longer than the maximum allowed length.
|
||||
|
||||
|
|
|
@ -202,7 +202,7 @@ class E2eKeysHandler:
|
|||
|
||||
# Now fetch any devices that we don't have in our cache
|
||||
@trace
|
||||
async def do_remote_query(destination):
|
||||
async def do_remote_query(destination: str) -> None:
|
||||
"""This is called when we are querying the device list of a user on
|
||||
a remote homeserver and their device list is not in the device list
|
||||
cache. If we share a room with this user and we're not querying for
|
||||
|
@ -447,7 +447,7 @@ class E2eKeysHandler:
|
|||
}
|
||||
|
||||
@trace
|
||||
async def claim_client_keys(destination):
|
||||
async def claim_client_keys(destination: str) -> None:
|
||||
set_tag("destination", destination)
|
||||
device_keys = remote_queries[destination]
|
||||
try:
|
||||
|
|
|
@ -25,6 +25,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError
|
|||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.builder import EventBuilder
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.types import StateMap, get_domain_from_id
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
|
@ -45,7 +46,11 @@ class EventAuthHandler:
|
|||
self._server_name = hs.hostname
|
||||
|
||||
async def check_from_context(
|
||||
self, room_version: str, event, context, do_sig_check=True
|
||||
self,
|
||||
room_version: str,
|
||||
event: EventBase,
|
||||
context: EventContext,
|
||||
do_sig_check: bool = True,
|
||||
) -> None:
|
||||
auth_event_ids = event.auth_event_ids()
|
||||
auth_events_by_id = await self._store.get_events(auth_event_ids)
|
||||
|
|
|
@ -1221,136 +1221,6 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
return missing_events
|
||||
|
||||
async def construct_auth_difference(
|
||||
self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase]
|
||||
) -> Dict:
|
||||
"""Given a local and remote auth chain, find the differences. This
|
||||
assumes that we have already processed all events in remote_auth
|
||||
|
||||
Params:
|
||||
local_auth
|
||||
remote_auth
|
||||
|
||||
Returns:
|
||||
dict
|
||||
"""
|
||||
|
||||
logger.debug("construct_auth_difference Start!")
|
||||
|
||||
# TODO: Make sure we are OK with local_auth or remote_auth having more
|
||||
# auth events in them than strictly necessary.
|
||||
|
||||
def sort_fun(ev):
|
||||
return ev.depth, ev.event_id
|
||||
|
||||
logger.debug("construct_auth_difference after sort_fun!")
|
||||
|
||||
# We find the differences by starting at the "bottom" of each list
|
||||
# and iterating up on both lists. The lists are ordered by depth and
|
||||
# then event_id, we iterate up both lists until we find the event ids
|
||||
# don't match. Then we look at depth/event_id to see which side is
|
||||
# missing that event, and iterate only up that list. Repeat.
|
||||
|
||||
remote_list = list(remote_auth)
|
||||
remote_list.sort(key=sort_fun)
|
||||
|
||||
local_list = list(local_auth)
|
||||
local_list.sort(key=sort_fun)
|
||||
|
||||
local_iter = iter(local_list)
|
||||
remote_iter = iter(remote_list)
|
||||
|
||||
logger.debug("construct_auth_difference before get_next!")
|
||||
|
||||
def get_next(it, opt=None):
|
||||
try:
|
||||
return next(it)
|
||||
except Exception:
|
||||
return opt
|
||||
|
||||
current_local = get_next(local_iter)
|
||||
current_remote = get_next(remote_iter)
|
||||
|
||||
logger.debug("construct_auth_difference before while")
|
||||
|
||||
missing_remotes = []
|
||||
missing_locals = []
|
||||
while current_local or current_remote:
|
||||
if current_remote is None:
|
||||
missing_locals.append(current_local)
|
||||
current_local = get_next(local_iter)
|
||||
continue
|
||||
|
||||
if current_local is None:
|
||||
missing_remotes.append(current_remote)
|
||||
current_remote = get_next(remote_iter)
|
||||
continue
|
||||
|
||||
if current_local.event_id == current_remote.event_id:
|
||||
current_local = get_next(local_iter)
|
||||
current_remote = get_next(remote_iter)
|
||||
continue
|
||||
|
||||
if current_local.depth < current_remote.depth:
|
||||
missing_locals.append(current_local)
|
||||
current_local = get_next(local_iter)
|
||||
continue
|
||||
|
||||
if current_local.depth > current_remote.depth:
|
||||
missing_remotes.append(current_remote)
|
||||
current_remote = get_next(remote_iter)
|
||||
continue
|
||||
|
||||
# They have the same depth, so we fall back to the event_id order
|
||||
if current_local.event_id < current_remote.event_id:
|
||||
missing_locals.append(current_local)
|
||||
current_local = get_next(local_iter)
|
||||
|
||||
if current_local.event_id > current_remote.event_id:
|
||||
missing_remotes.append(current_remote)
|
||||
current_remote = get_next(remote_iter)
|
||||
continue
|
||||
|
||||
logger.debug("construct_auth_difference after while")
|
||||
|
||||
# missing locals should be sent to the server
|
||||
# We should find why we are missing remotes, as they will have been
|
||||
# rejected.
|
||||
|
||||
# Remove events from missing_remotes if they are referencing a missing
|
||||
# remote. We only care about the "root" rejected ones.
|
||||
missing_remote_ids = [e.event_id for e in missing_remotes]
|
||||
base_remote_rejected = list(missing_remotes)
|
||||
for e in missing_remotes:
|
||||
for e_id in e.auth_event_ids():
|
||||
if e_id in missing_remote_ids:
|
||||
try:
|
||||
base_remote_rejected.remove(e)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
reason_map = {}
|
||||
|
||||
for e in base_remote_rejected:
|
||||
reason = await self.store.get_rejection_reason(e.event_id)
|
||||
if reason is None:
|
||||
# TODO: e is not in the current state, so we should
|
||||
# construct some proof of that.
|
||||
continue
|
||||
|
||||
reason_map[e.event_id] = reason
|
||||
|
||||
logger.debug("construct_auth_difference returning")
|
||||
|
||||
return {
|
||||
"auth_chain": local_auth,
|
||||
"rejects": {
|
||||
e.event_id: {"reason": reason_map[e.event_id], "proof": None}
|
||||
for e in base_remote_rejected
|
||||
},
|
||||
"missing": [e.event_id for e in missing_locals],
|
||||
}
|
||||
|
||||
@log_function
|
||||
async def exchange_third_party_invite(
|
||||
self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
|
||||
|
|
|
@ -1016,7 +1016,7 @@ class FederationEventHandler:
|
|||
except Exception:
|
||||
logger.exception("Failed to resync device for %s", sender)
|
||||
|
||||
async def _handle_marker_event(self, origin: str, marker_event: EventBase):
|
||||
async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None:
|
||||
"""Handles backfilling the insertion event when we receive a marker
|
||||
event that points to one.
|
||||
|
||||
|
@ -1109,7 +1109,7 @@ class FederationEventHandler:
|
|||
|
||||
event_map: Dict[str, EventBase] = {}
|
||||
|
||||
async def get_event(event_id: str):
|
||||
async def get_event(event_id: str) -> None:
|
||||
with nested_logging_context(event_id):
|
||||
try:
|
||||
event = await self._federation_client.get_pdu(
|
||||
|
@ -1218,7 +1218,7 @@ class FederationEventHandler:
|
|||
if not event_infos:
|
||||
return
|
||||
|
||||
async def prep(ev_info: _NewEventInfo):
|
||||
async def prep(ev_info: _NewEventInfo) -> EventContext:
|
||||
event = ev_info.event
|
||||
with nested_logging_context(suffix=event.event_id):
|
||||
res = await self._state_handler.compute_event_context(event)
|
||||
|
@ -1692,7 +1692,7 @@ class FederationEventHandler:
|
|||
|
||||
async def _run_push_actions_and_persist_event(
|
||||
self, event: EventBase, context: EventContext, backfilled: bool = False
|
||||
):
|
||||
) -> None:
|
||||
"""Run the push actions for a received event, and persist it.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Set
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterable, List, Set
|
||||
|
||||
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
|
||||
from synapse.types import GroupID, JsonDict, get_domain_from_id
|
||||
|
@ -25,12 +25,14 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_rerouter(func_name):
|
||||
def _create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]:
|
||||
"""Returns an async function that looks at the group id and calls the function
|
||||
on federation or the local group server if the group is local
|
||||
"""
|
||||
|
||||
async def f(self, group_id, *args, **kwargs):
|
||||
async def f(
|
||||
self: "GroupsLocalWorkerHandler", group_id: str, *args: Any, **kwargs: Any
|
||||
) -> JsonDict:
|
||||
if not GroupID.is_valid(group_id):
|
||||
raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -150,7 +150,7 @@ class InitialSyncHandler(BaseHandler):
|
|||
if limit is None:
|
||||
limit = 10
|
||||
|
||||
async def handle_room(event: RoomsForUser):
|
||||
async def handle_room(event: RoomsForUser) -> None:
|
||||
d: JsonDict = {
|
||||
"room_id": event.room_id,
|
||||
"membership": event.membership,
|
||||
|
@ -411,7 +411,7 @@ class InitialSyncHandler(BaseHandler):
|
|||
|
||||
presence_handler = self.hs.get_presence_handler()
|
||||
|
||||
async def get_presence():
|
||||
async def get_presence() -> List[JsonDict]:
|
||||
# If presence is disabled, return an empty list
|
||||
if not self.hs.config.server.use_presence:
|
||||
return []
|
||||
|
@ -428,7 +428,7 @@ class InitialSyncHandler(BaseHandler):
|
|||
for s in states
|
||||
]
|
||||
|
||||
async def get_receipts():
|
||||
async def get_receipts() -> List[JsonDict]:
|
||||
receipts = await self.store.get_linearized_receipts_for_room(
|
||||
room_id, to_key=now_token.receipt_key
|
||||
)
|
||||
|
|
|
@ -46,6 +46,7 @@ from synapse.events import EventBase
|
|||
from synapse.events.builder import EventBuilder
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.handlers.directory import DirectoryHandler
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
||||
|
@ -298,7 +299,7 @@ class MessageHandler:
|
|||
for user_id, profile in users_with_profile.items()
|
||||
}
|
||||
|
||||
def maybe_schedule_expiry(self, event: EventBase):
|
||||
def maybe_schedule_expiry(self, event: EventBase) -> None:
|
||||
"""Schedule the expiry of an event if there's not already one scheduled,
|
||||
or if the one running is for an event that will expire after the provided
|
||||
timestamp.
|
||||
|
@ -318,7 +319,7 @@ class MessageHandler:
|
|||
# a task scheduled for a timestamp that's sooner than the provided one.
|
||||
self._schedule_expiry_for_event(event.event_id, expiry_ts)
|
||||
|
||||
async def _schedule_next_expiry(self):
|
||||
async def _schedule_next_expiry(self) -> None:
|
||||
"""Retrieve the ID and the expiry timestamp of the next event to be expired,
|
||||
and schedule an expiry task for it.
|
||||
|
||||
|
@ -331,7 +332,7 @@ class MessageHandler:
|
|||
event_id, expiry_ts = res
|
||||
self._schedule_expiry_for_event(event_id, expiry_ts)
|
||||
|
||||
def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int):
|
||||
def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int) -> None:
|
||||
"""Schedule an expiry task for the provided event if there's not already one
|
||||
scheduled at a timestamp that's sooner than the provided one.
|
||||
|
||||
|
@ -367,7 +368,7 @@ class MessageHandler:
|
|||
event_id,
|
||||
)
|
||||
|
||||
async def _expire_event(self, event_id: str):
|
||||
async def _expire_event(self, event_id: str) -> None:
|
||||
"""Retrieve and expire an event that needs to be expired from the database.
|
||||
|
||||
If the event doesn't exist in the database, log it and delete the expiry date
|
||||
|
@ -1229,7 +1230,10 @@ class EventCreationHandler:
|
|||
self._external_cache_joined_hosts_updates[state_entry.state_group] = None
|
||||
|
||||
async def _validate_canonical_alias(
|
||||
self, directory_handler, room_alias_str: str, expected_room_id: str
|
||||
self,
|
||||
directory_handler: DirectoryHandler,
|
||||
room_alias_str: str,
|
||||
expected_room_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Ensure that the given room alias points to the expected room ID.
|
||||
|
@ -1477,7 +1481,7 @@ class EventCreationHandler:
|
|||
# If there's an expiry timestamp on the event, schedule its expiry.
|
||||
self._message_handler.maybe_schedule_expiry(event)
|
||||
|
||||
def _notify():
|
||||
def _notify() -> None:
|
||||
try:
|
||||
self.notifier.on_new_room_event(
|
||||
event, event_pos, max_stream_token, extra_users=extra_users
|
||||
|
@ -1523,7 +1527,7 @@ class EventCreationHandler:
|
|||
except Exception:
|
||||
logger.exception("Error bumping presence active time")
|
||||
|
||||
async def _send_dummy_events_to_fill_extremities(self):
|
||||
async def _send_dummy_events_to_fill_extremities(self) -> None:
|
||||
"""Background task to send dummy events into rooms that have a large
|
||||
number of extremities
|
||||
"""
|
||||
|
@ -1600,7 +1604,7 @@ class EventCreationHandler:
|
|||
)
|
||||
return False
|
||||
|
||||
def _expire_rooms_to_exclude_from_dummy_event_insertion(self):
|
||||
def _expire_rooms_to_exclude_from_dummy_event_insertion(self) -> None:
|
||||
expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
|
||||
to_expire = set()
|
||||
for room_id, time in self._rooms_to_exclude_from_dummy_event_insertion.items():
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import attr
|
||||
|
@ -249,11 +249,11 @@ class OidcHandler:
|
|||
class OidcError(Exception):
|
||||
"""Used to catch errors when calling the token_endpoint"""
|
||||
|
||||
def __init__(self, error, error_description=None):
|
||||
def __init__(self, error: str, error_description: Optional[str] = None):
|
||||
self.error = error
|
||||
self.error_description = error_description
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
if self.error_description:
|
||||
return f"{self.error}: {self.error_description}"
|
||||
return self.error
|
||||
|
@ -1057,13 +1057,13 @@ class JwtClientSecret:
|
|||
self._cached_secret = b""
|
||||
self._cached_secret_replacement_time = 0
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
# if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
|
||||
# encode_client_secret_basic, which calls "{}".format(secret), which ends up
|
||||
# here.
|
||||
return self._get_secret().decode("ascii")
|
||||
|
||||
def __bytes__(self):
|
||||
def __bytes__(self) -> bytes:
|
||||
# if client_auth_method is client_secret_post, then ClientAuth.prepare calls
|
||||
# encode_client_secret_post, which ends up here.
|
||||
return self._get_secret()
|
||||
|
@ -1197,21 +1197,21 @@ class OidcSessionTokenGenerator:
|
|||
)
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True)
|
||||
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
||||
class OidcSessionData:
|
||||
"""The attributes which are stored in a OIDC session cookie"""
|
||||
|
||||
# the Identity Provider being used
|
||||
idp_id = attr.ib(type=str)
|
||||
idp_id: str
|
||||
|
||||
# The `nonce` parameter passed to the OIDC provider.
|
||||
nonce = attr.ib(type=str)
|
||||
nonce: str
|
||||
|
||||
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
|
||||
client_redirect_url = attr.ib(type=str)
|
||||
client_redirect_url: str
|
||||
|
||||
# The session ID of the ongoing UI Auth ("" if this is a login)
|
||||
ui_auth_session_id = attr.ib(type=str)
|
||||
ui_auth_session_id: str
|
||||
|
||||
|
||||
class UserAttributeDict(TypedDict):
|
||||
|
@ -1290,20 +1290,20 @@ class OidcMappingProvider(Generic[C]):
|
|||
|
||||
|
||||
# Used to clear out "None" values in templates
|
||||
def jinja_finalize(thing):
|
||||
def jinja_finalize(thing: Any) -> Any:
|
||||
return thing if thing is not None else ""
|
||||
|
||||
|
||||
env = Environment(finalize=jinja_finalize)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class JinjaOidcMappingConfig:
|
||||
subject_claim = attr.ib(type=str)
|
||||
localpart_template = attr.ib(type=Optional[Template])
|
||||
display_name_template = attr.ib(type=Optional[Template])
|
||||
email_template = attr.ib(type=Optional[Template])
|
||||
extra_attributes = attr.ib(type=Dict[str, Template])
|
||||
subject_claim: str
|
||||
localpart_template: Optional[Template]
|
||||
display_name_template: Optional[Template]
|
||||
email_template: Optional[Template]
|
||||
extra_attributes: Dict[str, Template]
|
||||
|
||||
|
||||
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Set
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
|
@ -24,7 +26,7 @@ from synapse.logging.context import run_in_background
|
|||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import Requester
|
||||
from synapse.types import JsonDict, Requester
|
||||
from synapse.util.async_helpers import ReadWriteLock
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
@ -36,15 +38,12 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class PurgeStatus:
|
||||
"""Object tracking the status of a purge request
|
||||
|
||||
This class contains information on the progress of a purge request, for
|
||||
return by get_purge_status.
|
||||
|
||||
Attributes:
|
||||
status (int): Tracks whether this request has completed. One of
|
||||
STATUS_{ACTIVE,COMPLETE,FAILED}
|
||||
"""
|
||||
|
||||
STATUS_ACTIVE = 0
|
||||
|
@ -57,10 +56,10 @@ class PurgeStatus:
|
|||
STATUS_FAILED: "failed",
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.status = PurgeStatus.STATUS_ACTIVE
|
||||
# Tracks whether this request has completed. One of STATUS_{ACTIVE,COMPLETE,FAILED}.
|
||||
status: int = STATUS_ACTIVE
|
||||
|
||||
def asdict(self):
|
||||
def asdict(self) -> JsonDict:
|
||||
return {"status": PurgeStatus.STATUS_TEXT[self.status]}
|
||||
|
||||
|
||||
|
@ -107,7 +106,7 @@ class PaginationHandler:
|
|||
|
||||
async def purge_history_for_rooms_in_range(
|
||||
self, min_ms: Optional[int], max_ms: Optional[int]
|
||||
):
|
||||
) -> None:
|
||||
"""Purge outdated events from rooms within the given retention range.
|
||||
|
||||
If a default retention policy is defined in the server's configuration and its
|
||||
|
@ -291,7 +290,7 @@ class PaginationHandler:
|
|||
self._purges_in_progress_by_room.discard(room_id)
|
||||
|
||||
# remove the purge from the list 24 hours after it completes
|
||||
def clear_purge():
|
||||
def clear_purge() -> None:
|
||||
del self._purges_by_id[purge_id]
|
||||
|
||||
self.hs.get_reactor().callLater(24 * 3600, clear_purge)
|
||||
|
|
|
@ -26,18 +26,22 @@ import contextlib
|
|||
import logging
|
||||
from bisect import bisect
|
||||
from contextlib import contextmanager
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
|
@ -240,7 +244,7 @@ class BasePresenceHandler(abc.ABC):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def bump_presence_active_time(self, user: UserID):
|
||||
async def bump_presence_active_time(self, user: UserID) -> None:
|
||||
"""We've seen the user do something that indicates they're interacting
|
||||
with the app.
|
||||
"""
|
||||
|
@ -274,7 +278,7 @@ class BasePresenceHandler(abc.ABC):
|
|||
|
||||
async def process_replication_rows(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||
):
|
||||
) -> None:
|
||||
"""Process streams received over replication."""
|
||||
await self._federation_queue.process_replication_rows(
|
||||
stream_name, instance_name, token, rows
|
||||
|
@ -286,7 +290,7 @@ class BasePresenceHandler(abc.ABC):
|
|||
|
||||
async def maybe_send_presence_to_interested_destinations(
|
||||
self, states: List[UserPresenceState]
|
||||
):
|
||||
) -> None:
|
||||
"""If this instance is a federation sender, send the states to all
|
||||
destinations that are interested. Filters out any states for remote
|
||||
users.
|
||||
|
@ -309,7 +313,7 @@ class BasePresenceHandler(abc.ABC):
|
|||
for destination, host_states in hosts_to_states.items():
|
||||
self._federation.send_presence_to_destinations(host_states, [destination])
|
||||
|
||||
async def send_full_presence_to_users(self, user_ids: Collection[str]):
|
||||
async def send_full_presence_to_users(self, user_ids: Collection[str]) -> None:
|
||||
"""
|
||||
Adds to the list of users who should receive a full snapshot of presence
|
||||
upon their next sync. Note that this only works for local users.
|
||||
|
@ -363,7 +367,12 @@ class BasePresenceHandler(abc.ABC):
|
|||
class _NullContextManager(ContextManager[None]):
|
||||
"""A context manager which does nothing."""
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
|
@ -468,7 +477,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
|
|||
if self._user_to_num_current_syncs[user_id] == 1:
|
||||
self.mark_as_coming_online(user_id)
|
||||
|
||||
def _end():
|
||||
def _end() -> None:
|
||||
# We check that the user_id is in user_to_num_current_syncs because
|
||||
# user_to_num_current_syncs may have been cleared if we are
|
||||
# shutting down.
|
||||
|
@ -480,7 +489,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
|
|||
self.mark_as_going_offline(user_id)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _user_syncing():
|
||||
def _user_syncing() -> Generator[None, None, None]:
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
|
@ -503,7 +512,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
|
|||
|
||||
async def process_replication_rows(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||
):
|
||||
) -> None:
|
||||
await super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
if stream_name != PresenceStream.NAME:
|
||||
|
@ -689,7 +698,7 @@ class PresenceHandler(BasePresenceHandler):
|
|||
# Start a LoopingCall in 30s that fires every 5s.
|
||||
# The initial delay is to allow disconnected clients a chance to
|
||||
# reconnect before we treat them as offline.
|
||||
def run_timeout_handler():
|
||||
def run_timeout_handler() -> Awaitable[None]:
|
||||
return run_as_background_process(
|
||||
"handle_presence_timeouts", self._handle_timeouts
|
||||
)
|
||||
|
@ -698,7 +707,7 @@ class PresenceHandler(BasePresenceHandler):
|
|||
30, self.clock.looping_call, run_timeout_handler, 5000
|
||||
)
|
||||
|
||||
def run_persister():
|
||||
def run_persister() -> Awaitable[None]:
|
||||
return run_as_background_process(
|
||||
"persist_presence_changes", self._persist_unpersisted_changes
|
||||
)
|
||||
|
@ -942,8 +951,8 @@ class PresenceHandler(BasePresenceHandler):
|
|||
when users disconnect/reconnect.
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
affect_presence (bool): If false this function will be a no-op.
|
||||
user_id
|
||||
affect_presence: If false this function will be a no-op.
|
||||
Useful for streams that are not associated with an actual
|
||||
client that is being used by a user.
|
||||
"""
|
||||
|
@ -978,7 +987,7 @@ class PresenceHandler(BasePresenceHandler):
|
|||
]
|
||||
)
|
||||
|
||||
async def _end():
|
||||
async def _end() -> None:
|
||||
try:
|
||||
self.user_to_num_current_syncs[user_id] -= 1
|
||||
|
||||
|
@ -994,7 +1003,7 @@ class PresenceHandler(BasePresenceHandler):
|
|||
logger.exception("Error updating presence after sync")
|
||||
|
||||
@contextmanager
|
||||
def _user_syncing():
|
||||
def _user_syncing() -> Generator[None, None, None]:
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
|
@ -1264,7 +1273,7 @@ class PresenceHandler(BasePresenceHandler):
|
|||
if self._event_processing:
|
||||
return
|
||||
|
||||
async def _process_presence():
|
||||
async def _process_presence() -> None:
|
||||
assert not self._event_processing
|
||||
|
||||
self._event_processing = True
|
||||
|
@ -1513,7 +1522,7 @@ class PresenceEventSource:
|
|||
room_ids: Optional[List[str]] = None,
|
||||
include_offline: bool = True,
|
||||
explicit_room_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> Tuple[List[UserPresenceState], int]:
|
||||
# The process for getting presence events are:
|
||||
# 1. Get the rooms the user is in.
|
||||
|
@ -2074,7 +2083,7 @@ class PresenceFederationQueue:
|
|||
if self._queue_presence_updates:
|
||||
self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
|
||||
|
||||
def _clear_queue(self):
|
||||
def _clear_queue(self) -> None:
|
||||
"""Clear out older entries from the queue."""
|
||||
clear_before = self._clock.time_msec() - self._KEEP_ITEMS_IN_QUEUE_FOR_MS
|
||||
|
||||
|
@ -2205,7 +2214,7 @@ class PresenceFederationQueue:
|
|||
|
||||
async def process_replication_rows(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||
):
|
||||
) -> None:
|
||||
if stream_name != PresenceFederationStream.NAME:
|
||||
return
|
||||
|
||||
|
|
|
@ -254,7 +254,7 @@ class ProfileHandler(BaseHandler):
|
|||
requester: Requester,
|
||||
new_avatar_url: str,
|
||||
by_admin: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
"""Set a new avatar URL for a user.
|
||||
|
||||
Args:
|
||||
|
@ -425,7 +425,7 @@ class ProfileHandler(BaseHandler):
|
|||
raise
|
||||
|
||||
@wrap_as_background_process("Update remote profile")
|
||||
async def _update_remote_profile_cache(self):
|
||||
async def _update_remote_profile_cache(self) -> None:
|
||||
"""Called periodically to check profiles of remote users we haven't
|
||||
checked in a while.
|
||||
"""
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import ReadReceiptEventFields
|
||||
from synapse.appservice import ApplicationService
|
||||
|
@ -216,7 +216,7 @@ class ReceiptEventSource:
|
|||
return visible_events
|
||||
|
||||
async def get_new_events(
|
||||
self, from_key: int, room_ids: List[str], user: UserID, **kwargs
|
||||
self, from_key: int, room_ids: List[str], user: UserID, **kwargs: Any
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
from_key = int(from_key)
|
||||
to_key = self.get_current_key()
|
||||
|
|
|
@ -125,7 +125,7 @@ class RegistrationHandler(BaseHandler):
|
|||
localpart: str,
|
||||
guest_access_token: Optional[str] = None,
|
||||
assigned_user_id: Optional[str] = None,
|
||||
):
|
||||
) -> None:
|
||||
if types.contains_invalid_mxid_characters(localpart):
|
||||
raise SynapseError(
|
||||
400,
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
# Copyright 2014 - 2016 OpenMarket Ltd
|
||||
# Copyright 2018-2019 New Vector Ltd
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -186,7 +184,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
|
||||
async def _upgrade_room(
|
||||
self, requester: Requester, old_room_id: str, new_version: RoomVersion
|
||||
):
|
||||
) -> str:
|
||||
"""
|
||||
Args:
|
||||
requester: the user requesting the upgrade
|
||||
|
@ -512,7 +510,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
old_room_id: str,
|
||||
new_room_id: str,
|
||||
old_room_state: StateMap[str],
|
||||
):
|
||||
) -> None:
|
||||
# check to see if we have a canonical alias.
|
||||
canonical_alias_event = None
|
||||
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
|
||||
|
@ -902,7 +900,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
|
||||
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
|
||||
|
||||
def create(etype: str, content: JsonDict, **kwargs) -> JsonDict:
|
||||
def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
|
||||
e = {"type": etype, "content": content}
|
||||
|
||||
e.update(event_keys)
|
||||
|
@ -910,7 +908,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
|
||||
return e
|
||||
|
||||
async def send(etype: str, content: JsonDict, **kwargs) -> int:
|
||||
async def send(etype: str, content: JsonDict, **kwargs: Any) -> int:
|
||||
event = create(etype, content, **kwargs)
|
||||
logger.debug("Sending %s in new room", etype)
|
||||
# Allow these events to be sent even if the user is shadow-banned to
|
||||
|
@ -1033,7 +1031,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
creator_id: str,
|
||||
is_public: bool,
|
||||
room_version: RoomVersion,
|
||||
):
|
||||
) -> str:
|
||||
# autogen room IDs and try to create it. We may clash, so just
|
||||
# try a few times till one goes through, giving up eventually.
|
||||
attempts = 0
|
||||
|
@ -1097,7 +1095,7 @@ class RoomContextHandler:
|
|||
users = await self.store.get_users_in_room(room_id)
|
||||
is_peeking = user.to_string() not in users
|
||||
|
||||
async def filter_evts(events):
|
||||
async def filter_evts(events: List[EventBase]) -> List[EventBase]:
|
||||
if use_admin_priviledge:
|
||||
return events
|
||||
return await filter_events_for_client(
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Optional, Tuple
|
||||
|
||||
import msgpack
|
||||
from unpaddedbase64 import decode_base64, encode_base64
|
||||
|
@ -33,7 +33,7 @@ from synapse.api.errors import (
|
|||
SynapseError,
|
||||
)
|
||||
from synapse.types import JsonDict, ThirdPartyInstanceID
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.caches.descriptors import _CacheContext, cached
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
@ -169,7 +169,7 @@ class RoomListHandler(BaseHandler):
|
|||
ignore_non_federatable=from_federation,
|
||||
)
|
||||
|
||||
def build_room_entry(room):
|
||||
def build_room_entry(room: JsonDict) -> JsonDict:
|
||||
entry = {
|
||||
"room_id": room["room_id"],
|
||||
"name": room["name"],
|
||||
|
@ -249,10 +249,10 @@ class RoomListHandler(BaseHandler):
|
|||
self,
|
||||
room_id: str,
|
||||
num_joined_users: int,
|
||||
cache_context,
|
||||
cache_context: _CacheContext,
|
||||
with_alias: bool = True,
|
||||
allow_private: bool = False,
|
||||
) -> Optional[dict]:
|
||||
) -> Optional[JsonDict]:
|
||||
"""Returns the entry for a room
|
||||
|
||||
Args:
|
||||
|
@ -507,7 +507,7 @@ class RoomListNextBatch(
|
|||
)
|
||||
)
|
||||
|
||||
def copy_and_replace(self, **kwds) -> "RoomListNextBatch":
|
||||
def copy_and_replace(self, **kwds: Any) -> "RoomListNextBatch":
|
||||
return self._replace(**kwds)
|
||||
|
||||
|
||||
|
|
|
@ -225,7 +225,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
room_id: Optional[str],
|
||||
n_invites: int,
|
||||
update: bool = True,
|
||||
):
|
||||
) -> None:
|
||||
"""Ratelimit more than one invite sent by the given requester in the given room.
|
||||
|
||||
Args:
|
||||
|
@ -249,7 +249,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
requester: Optional[Requester],
|
||||
room_id: Optional[str],
|
||||
invitee_user_id: str,
|
||||
):
|
||||
) -> None:
|
||||
"""Ratelimit invites by room and by target user.
|
||||
|
||||
If room ID is missing then we just rate limit by target user.
|
||||
|
@ -386,7 +386,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
return result_event.event_id, result_event.internal_metadata.stream_ordering
|
||||
|
||||
async def copy_room_tags_and_direct_to_room(
|
||||
self, old_room_id, new_room_id, user_id
|
||||
self, old_room_id: str, new_room_id: str, user_id: str
|
||||
) -> None:
|
||||
"""Copies the tags and direct room state from one room to another.
|
||||
|
||||
|
@ -1030,7 +1030,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
event: EventBase,
|
||||
context: EventContext,
|
||||
ratelimit: bool = True,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Change the membership status of a user in a room.
|
||||
|
||||
|
|
|
@ -541,7 +541,7 @@ class RoomSummaryHandler:
|
|||
origin: str,
|
||||
requested_room_id: str,
|
||||
suggested_only: bool,
|
||||
):
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Implementation of the room hierarchy Federation API.
|
||||
|
||||
|
|
|
@ -40,15 +40,15 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class Saml2SessionData:
|
||||
"""Data we track about SAML2 sessions"""
|
||||
|
||||
# time the session was created, in milliseconds
|
||||
creation_time = attr.ib()
|
||||
creation_time: int
|
||||
# The user interactive authentication session ID associated with this SAML
|
||||
# session (or None if this SAML session is for an initial login).
|
||||
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
|
||||
ui_auth_session_id: Optional[str] = None
|
||||
|
||||
|
||||
class SamlHandler(BaseHandler):
|
||||
|
@ -359,7 +359,7 @@ class SamlHandler(BaseHandler):
|
|||
|
||||
return remote_user_id
|
||||
|
||||
def expire_sessions(self):
|
||||
def expire_sessions(self) -> None:
|
||||
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
|
||||
to_expire = set()
|
||||
for reqid, data in self._outstanding_requests_dict.items():
|
||||
|
@ -391,10 +391,10 @@ MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = {
|
|||
}
|
||||
|
||||
|
||||
@attr.s
|
||||
@attr.s(auto_attribs=True)
|
||||
class SamlConfig:
|
||||
mxid_source_attribute = attr.ib()
|
||||
mxid_mapper = attr.ib()
|
||||
mxid_source_attribute: str
|
||||
mxid_mapper: Callable[[str], str]
|
||||
|
||||
|
||||
class DefaultSamlMappingProvider:
|
||||
|
|
|
@ -17,7 +17,7 @@ import logging
|
|||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from pkg_resources import parse_version
|
||||
|
||||
|
@ -79,7 +79,7 @@ async def _sendmail(
|
|||
msg = BytesIO(msg_bytes)
|
||||
d: "Deferred[object]" = Deferred()
|
||||
|
||||
def build_sender_factory(**kwargs) -> ESMTPSenderFactory:
|
||||
def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory:
|
||||
return ESMTPSenderFactory(
|
||||
username,
|
||||
password,
|
||||
|
|
|
@ -205,7 +205,7 @@ class SsoHandler:
|
|||
|
||||
self._consent_at_registration = hs.config.consent.user_consent_at_registration
|
||||
|
||||
def register_identity_provider(self, p: SsoIdentityProvider):
|
||||
def register_identity_provider(self, p: SsoIdentityProvider) -> None:
|
||||
p_id = p.idp_id
|
||||
assert p_id not in self._identity_providers
|
||||
self._identity_providers[p_id] = p
|
||||
|
@ -856,7 +856,7 @@ class SsoHandler:
|
|||
|
||||
async def handle_terms_accepted(
|
||||
self, request: Request, session_id: str, terms_version: str
|
||||
):
|
||||
) -> None:
|
||||
"""Handle a request to the new-user 'consent' endpoint
|
||||
|
||||
Will serve an HTTP response to the request.
|
||||
|
@ -959,7 +959,7 @@ class SsoHandler:
|
|||
new_user=True,
|
||||
)
|
||||
|
||||
def _expire_old_sessions(self):
|
||||
def _expire_old_sessions(self) -> None:
|
||||
to_expire = []
|
||||
now = int(self._clock.time_msec())
|
||||
|
||||
|
|
|
@ -68,7 +68,7 @@ class StatsHandler:
|
|||
|
||||
self._is_processing = True
|
||||
|
||||
async def process():
|
||||
async def process() -> None:
|
||||
try:
|
||||
await self._unsafe_process()
|
||||
finally:
|
||||
|
|
|
@ -364,7 +364,9 @@ class SyncHandler:
|
|||
)
|
||||
else:
|
||||
|
||||
async def current_sync_callback(before_token, after_token) -> SyncResult:
|
||||
async def current_sync_callback(
|
||||
before_token: StreamToken, after_token: StreamToken
|
||||
) -> SyncResult:
|
||||
return await self.current_sync_for_user(sync_config, since_token)
|
||||
|
||||
result = await self.notifier.wait_for_events(
|
||||
|
@ -1532,9 +1534,9 @@ class SyncHandler:
|
|||
newly_joined_rooms = room_changes.newly_joined_rooms
|
||||
newly_left_rooms = room_changes.newly_left_rooms
|
||||
|
||||
async def handle_room_entries(room_entry: "RoomSyncResultBuilder"):
|
||||
async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None:
|
||||
logger.debug("Generating room entry for %s", room_entry.room_id)
|
||||
res = await self._generate_room_entry(
|
||||
await self._generate_room_entry(
|
||||
sync_result_builder,
|
||||
ignored_users,
|
||||
room_entry,
|
||||
|
@ -1544,7 +1546,6 @@ class SyncHandler:
|
|||
always_include=sync_result_builder.full_state,
|
||||
)
|
||||
logger.debug("Generated room entry for %s", room_entry.room_id)
|
||||
return res
|
||||
|
||||
await concurrently_execute(handle_room_entries, room_entries, 10)
|
||||
|
||||
|
@ -1925,7 +1926,7 @@ class SyncHandler:
|
|||
tags: Optional[Dict[str, Dict[str, Any]]],
|
||||
account_data: Dict[str, JsonDict],
|
||||
always_include: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
"""Populates the `joined` and `archived` section of `sync_result_builder`
|
||||
based on the `room_builder`.
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
import logging
|
||||
import random
|
||||
from collections import namedtuple
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
|
||||
from synapse.appservice import ApplicationService
|
||||
|
@ -485,7 +485,7 @@ class TypingNotificationEventSource:
|
|||
return (events, handler._latest_room_serial)
|
||||
|
||||
async def get_new_events(
|
||||
self, from_key: int, room_ids: Iterable[str], **kwargs
|
||||
self, from_key: int, room_ids: Iterable[str], **kwargs: Any
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
with Measure(self.clock, "typing.get_new_events"):
|
||||
from_key = int(from_key)
|
||||
|
|
|
@ -70,7 +70,7 @@ class DummyAuthChecker(UserInteractiveAuthChecker):
|
|||
class TermsAuthChecker(UserInteractiveAuthChecker):
|
||||
AUTH_TYPE = LoginType.TERMS
|
||||
|
||||
def is_enabled(self):
|
||||
def is_enabled(self) -> bool:
|
||||
return True
|
||||
|
||||
async def check_auth(self, authdict: dict, clientip: str) -> Any:
|
||||
|
|
|
@ -114,7 +114,7 @@ class UserDirectoryHandler(StateDeltasHandler):
|
|||
if self._is_processing:
|
||||
return
|
||||
|
||||
async def process():
|
||||
async def process() -> None:
|
||||
try:
|
||||
await self._unsafe_process()
|
||||
finally:
|
||||
|
|
Loading…
Reference in a new issue