Add types to synapse.util. (#10601)

This commit is contained in:
reivilibre 2021-09-10 17:03:18 +01:00 committed by GitHub
parent ceab5a4bfa
commit 524b8ead77
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
41 changed files with 400 additions and 253 deletions

1
changelog.d/10601.misc Normal file
View file

@ -0,0 +1 @@
Add type annotations to the synapse.util package.

View file

@ -74,17 +74,7 @@ files =
synapse/storage/util, synapse/storage/util,
synapse/streams, synapse/streams,
synapse/types.py, synapse/types.py,
synapse/util/async_helpers.py, synapse/util,
synapse/util/caches,
synapse/util/daemonize.py,
synapse/util/hash.py,
synapse/util/iterutils.py,
synapse/util/linked_list.py,
synapse/util/metrics.py,
synapse/util/macaroons.py,
synapse/util/module_loader.py,
synapse/util/msisdn.py,
synapse/util/stringutils.py,
synapse/visibility.py, synapse/visibility.py,
tests/replication, tests/replication,
tests/test_event_auth.py, tests/test_event_auth.py,
@ -102,6 +92,69 @@ files =
[mypy-synapse.rest.client.*] [mypy-synapse.rest.client.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.util.batching_queue]
disallow_untyped_defs = True
[mypy-synapse.util.caches.dictionary_cache]
disallow_untyped_defs = True
[mypy-synapse.util.file_consumer]
disallow_untyped_defs = True
[mypy-synapse.util.frozenutils]
disallow_untyped_defs = True
[mypy-synapse.util.hash]
disallow_untyped_defs = True
[mypy-synapse.util.httpresourcetree]
disallow_untyped_defs = True
[mypy-synapse.util.iterutils]
disallow_untyped_defs = True
[mypy-synapse.util.linked_list]
disallow_untyped_defs = True
[mypy-synapse.util.logcontext]
disallow_untyped_defs = True
[mypy-synapse.util.logformatter]
disallow_untyped_defs = True
[mypy-synapse.util.macaroons]
disallow_untyped_defs = True
[mypy-synapse.util.manhole]
disallow_untyped_defs = True
[mypy-synapse.util.module_loader]
disallow_untyped_defs = True
[mypy-synapse.util.msisdn]
disallow_untyped_defs = True
[mypy-synapse.util.ratelimitutils]
disallow_untyped_defs = True
[mypy-synapse.util.retryutils]
disallow_untyped_defs = True
[mypy-synapse.util.rlimit]
disallow_untyped_defs = True
[mypy-synapse.util.stringutils]
disallow_untyped_defs = True
[mypy-synapse.util.templates]
disallow_untyped_defs = True
[mypy-synapse.util.threepids]
disallow_untyped_defs = True
[mypy-synapse.util.wheel_timer]
disallow_untyped_defs = True
[mypy-pymacaroons.*] [mypy-pymacaroons.*]
ignore_missing_imports = True ignore_missing_imports = True

View file

@ -73,4 +73,4 @@ class RedisFactory(protocol.ReconnectingClientFactory):
def buildProtocol(self, addr) -> RedisProtocol: ... def buildProtocol(self, addr) -> RedisProtocol: ...
class SubscriberFactory(RedisFactory): class SubscriberFactory(RedisFactory):
def __init__(self): ... def __init__(self) -> None: ...

View file

@ -46,7 +46,7 @@ class Ratelimiter:
# * How many times an action has occurred since a point in time # * How many times an action has occurred since a point in time
# * The point in time # * The point in time
# * The rate_hz of this particular entry. This can vary per request # * The rate_hz of this particular entry. This can vary per request
self.actions: OrderedDict[Hashable, Tuple[float, int, float]] = OrderedDict() self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
async def can_do_action( async def can_do_action(
self, self,
@ -56,7 +56,7 @@ class Ratelimiter:
burst_count: Optional[int] = None, burst_count: Optional[int] = None,
update: bool = True, update: bool = True,
n_actions: int = 1, n_actions: int = 1,
_time_now_s: Optional[int] = None, _time_now_s: Optional[float] = None,
) -> Tuple[bool, float]: ) -> Tuple[bool, float]:
"""Can the entity (e.g. user or IP address) perform the action? """Can the entity (e.g. user or IP address) perform the action?
@ -160,7 +160,7 @@ class Ratelimiter:
return allowed, time_allowed return allowed, time_allowed
def _prune_message_counts(self, time_now_s: int): def _prune_message_counts(self, time_now_s: float):
"""Remove message count entries that have not exceeded their defined """Remove message count entries that have not exceeded their defined
rate_hz limit rate_hz limit
@ -188,7 +188,7 @@ class Ratelimiter:
burst_count: Optional[int] = None, burst_count: Optional[int] = None,
update: bool = True, update: bool = True,
n_actions: int = 1, n_actions: int = 1,
_time_now_s: Optional[int] = None, _time_now_s: Optional[float] = None,
): ):
"""Checks if an action can be performed. If not, raises a LimitExceededError """Checks if an action can be performed. If not, raises a LimitExceededError

View file

@ -14,6 +14,8 @@
from typing import Dict, Optional from typing import Dict, Optional
import attr
from ._base import Config from ._base import Config
@ -29,18 +31,13 @@ class RateLimitConfig:
self.burst_count = int(config.get("burst_count", defaults["burst_count"])) self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
@attr.s(auto_attribs=True)
class FederationRateLimitConfig: class FederationRateLimitConfig:
_items_and_default = { window_size: int = 1000
"window_size": 1000, sleep_limit: int = 10
"sleep_limit": 10, sleep_delay: int = 500
"sleep_delay": 500, reject_limit: int = 50
"reject_limit": 50, concurrent: int = 3
"concurrent": 3,
}
def __init__(self, **kwargs):
for i in self._items_and_default.keys():
setattr(self, i, kwargs.get(i) or self._items_and_default[i])
class RatelimitConfig(Config): class RatelimitConfig(Config):
@ -69,11 +66,15 @@ class RatelimitConfig(Config):
else: else:
self.rc_federation = FederationRateLimitConfig( self.rc_federation = FederationRateLimitConfig(
**{ **{
k: v
for k, v in {
"window_size": config.get("federation_rc_window_size"), "window_size": config.get("federation_rc_window_size"),
"sleep_limit": config.get("federation_rc_sleep_limit"), "sleep_limit": config.get("federation_rc_sleep_limit"),
"sleep_delay": config.get("federation_rc_sleep_delay"), "sleep_delay": config.get("federation_rc_sleep_delay"),
"reject_limit": config.get("federation_rc_reject_limit"), "reject_limit": config.get("federation_rc_reject_limit"),
"concurrent": config.get("federation_rc_concurrent"), "concurrent": config.get("federation_rc_concurrent"),
}.items()
if v is not None
} }
) )

View file

@ -22,6 +22,7 @@ from prometheus_client import Counter
from typing_extensions import Literal from typing_extensions import Literal
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.interfaces import IDelayedCall
import synapse.metrics import synapse.metrics
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
@ -284,7 +285,9 @@ class FederationSender(AbstractFederationSender):
) )
# wake up destinations that have outstanding PDUs to be caught up # wake up destinations that have outstanding PDUs to be caught up
self._catchup_after_startup_timer = self.clock.call_later( self._catchup_after_startup_timer: Optional[
IDelayedCall
] = self.clock.call_later(
CATCH_UP_STARTUP_DELAY_SEC, CATCH_UP_STARTUP_DELAY_SEC,
run_as_background_process, run_as_background_process,
"wake_destinations_needing_catchup", "wake_destinations_needing_catchup",
@ -406,7 +409,7 @@ class FederationSender(AbstractFederationSender):
now = self.clock.time_msec() now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id) ts = await self.store.get_received_ts(event.event_id)
assert ts is not None
synapse.metrics.event_processing_lag_by_event.labels( synapse.metrics.event_processing_lag_by_event.labels(
"federation_sender" "federation_sender"
).observe((now - ts) / 1000) ).observe((now - ts) / 1000)
@ -435,6 +438,7 @@ class FederationSender(AbstractFederationSender):
if events: if events:
now = self.clock.time_msec() now = self.clock.time_msec()
ts = await self.store.get_received_ts(events[-1].event_id) ts = await self.store.get_received_ts(events[-1].event_id)
assert ts is not None
synapse.metrics.event_processing_lag.labels( synapse.metrics.event_processing_lag.labels(
"federation_sender" "federation_sender"

View file

@ -398,6 +398,7 @@ class AccountValidityHandler:
""" """
now = self.clock.time_msec() now = self.clock.time_msec()
if expiration_ts is None: if expiration_ts is None:
assert self._account_validity_period is not None
expiration_ts = now + self._account_validity_period expiration_ts = now + self._account_validity_period
await self.store.set_account_validity_for_user( await self.store.set_account_validity_for_user(

View file

@ -131,6 +131,8 @@ class ApplicationServicesHandler:
now = self.clock.time_msec() now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id) ts = await self.store.get_received_ts(event.event_id)
assert ts is not None
synapse.metrics.event_processing_lag_by_event.labels( synapse.metrics.event_processing_lag_by_event.labels(
"appservice_sender" "appservice_sender"
).observe((now - ts) / 1000) ).observe((now - ts) / 1000)
@ -166,6 +168,7 @@ class ApplicationServicesHandler:
if events: if events:
now = self.clock.time_msec() now = self.clock.time_msec()
ts = await self.store.get_received_ts(events[-1].event_id) ts = await self.store.get_received_ts(events[-1].event_id)
assert ts is not None
synapse.metrics.event_processing_lag.labels( synapse.metrics.event_processing_lag.labels(
"appservice_sender" "appservice_sender"

View file

@ -28,6 +28,7 @@ from bisect import bisect
from contextlib import contextmanager from contextlib import contextmanager
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any,
Callable, Callable,
Collection, Collection,
Dict, Dict,
@ -615,7 +616,7 @@ class PresenceHandler(BasePresenceHandler):
super().__init__(hs) super().__init__(hs)
self.hs = hs self.hs = hs
self.server_name = hs.hostname self.server_name = hs.hostname
self.wheel_timer = WheelTimer() self.wheel_timer: WheelTimer[str] = WheelTimer()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self._presence_enabled = hs.config.use_presence self._presence_enabled = hs.config.use_presence
@ -924,7 +925,7 @@ class PresenceHandler(BasePresenceHandler):
prev_state = await self.current_state_for_user(user_id) prev_state = await self.current_state_for_user(user_id)
new_fields = {"last_active_ts": self.clock.time_msec()} new_fields: Dict[str, Any] = {"last_active_ts": self.clock.time_msec()}
if prev_state.state == PresenceState.UNAVAILABLE: if prev_state.state == PresenceState.UNAVAILABLE:
new_fields["state"] = PresenceState.ONLINE new_fields["state"] = PresenceState.ONLINE

View file

@ -73,7 +73,7 @@ class FollowerTypingHandler:
self._room_typing: Dict[str, Set[str]] = {} self._room_typing: Dict[str, Set[str]] = {}
self._member_last_federation_poke: Dict[RoomMember, int] = {} self._member_last_federation_poke: Dict[RoomMember, int] = {}
self.wheel_timer = WheelTimer(bucket_size=5000) self.wheel_timer: WheelTimer[RoomMember] = WheelTimer(bucket_size=5000)
self._latest_room_serial = 0 self._latest_room_serial = 0
self.clock.looping_call(self._handle_timeouts, 5000) self.clock.looping_call(self._handle_timeouts, 5000)

View file

@ -330,11 +330,11 @@ class UsernameAvailabilityRestServlet(RestServlet):
# Artificially delay requests if rate > sleep_limit/window_size # Artificially delay requests if rate > sleep_limit/window_size
sleep_limit=1, sleep_limit=1,
# Amount of artificial delay to apply # Amount of artificial delay to apply
sleep_msec=1000, sleep_delay=1000,
# Error with 429 if more than reject_limit requests are queued # Error with 429 if more than reject_limit requests are queued
reject_limit=1, reject_limit=1,
# Allow 1 request at a time # Allow 1 request at a time
concurrent_requests=1, concurrent=1,
), ),
) )
@ -763,7 +763,10 @@ class RegisterRestServlet(RestServlet):
Returns: Returns:
dictionary for response from /register dictionary for response from /register
""" """
result = {"user_id": user_id, "home_server": self.hs.hostname} result: JsonDict = {
"user_id": user_id,
"home_server": self.hs.hostname,
}
if not params.get("inhibit_login", False): if not params.get("inhibit_login", False):
device_id = params.get("device_id") device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name") initial_display_name = params.get("initial_device_display_name")
@ -814,7 +817,7 @@ class RegisterRestServlet(RestServlet):
user_id, device_id, initial_display_name, is_guest=True user_id, device_id, initial_display_name, is_guest=True
) )
result = { result: JsonDict = {
"user_id": user_id, "user_id": user_id,
"device_id": device_id, "device_id": device_id,
"access_token": access_token, "access_token": access_token,

View file

@ -52,7 +52,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
yield hs.config.sso.sso_template_dir yield hs.config.sso.sso_template_dir
yield hs.config.sso.default_template_dir yield hs.config.sso.default_template_dir
self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config)
async def _async_render_GET(self, request: Request) -> None: async def _async_render_GET(self, request: Request) -> None:
try: try:

View file

@ -80,7 +80,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
yield hs.config.sso.sso_template_dir yield hs.config.sso.sso_template_dir
yield hs.config.sso.default_template_dir yield hs.config.sso.default_template_dir
self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config)
async def _async_render_GET(self, request: Request) -> None: async def _async_render_GET(self, request: Request) -> None:
try: try:

View file

@ -1091,6 +1091,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
delta equal to 10% of the validity period. delta equal to 10% of the validity period.
""" """
now_ms = self._clock.time_msec() now_ms = self._clock.time_msec()
assert self._account_validity_period is not None
expiration_ts = now_ms + self._account_validity_period expiration_ts = now_ms + self._account_validity_period
if use_delta: if use_delta:

View file

@ -38,6 +38,7 @@ from twisted.internet.interfaces import (
IReactorCore, IReactorCore,
IReactorPluggableNameResolver, IReactorPluggableNameResolver,
IReactorTCP, IReactorTCP,
IReactorThreads,
IReactorTime, IReactorTime,
) )
@ -63,7 +64,12 @@ JsonDict = Dict[str, Any]
# Note that this seems to require inheriting *directly* from Interface in order # Note that this seems to require inheriting *directly* from Interface in order
# for mypy-zope to realize it is an interface. # for mypy-zope to realize it is an interface.
class ISynapseReactor( class ISynapseReactor(
IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface IReactorTCP,
IReactorPluggableNameResolver,
IReactorTime,
IReactorCore,
IReactorThreads,
Interface,
): ):
"""The interfaces necessary for Synapse to function.""" """The interfaces necessary for Synapse to function."""

View file

@ -15,27 +15,35 @@
import json import json
import logging import logging
import re import re
from typing import Pattern import typing
from typing import Any, Callable, Dict, Generator, Pattern
import attr import attr
from frozendict import frozendict from frozendict import frozendict
from twisted.internet import defer, task from twisted.internet import defer, task
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IDelayedCall, IReactorTime
from twisted.internet.task import LoopingCall
from twisted.python.failure import Failure
from synapse.logging import context from synapse.logging import context
if typing.TYPE_CHECKING:
pass
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_WILDCARD_RUN = re.compile(r"([\?\*]+)") _WILDCARD_RUN = re.compile(r"([\?\*]+)")
def _reject_invalid_json(val): def _reject_invalid_json(val: Any) -> None:
"""Do not allow Infinity, -Infinity, or NaN values in JSON.""" """Do not allow Infinity, -Infinity, or NaN values in JSON."""
raise ValueError("Invalid JSON value: '%s'" % val) raise ValueError("Invalid JSON value: '%s'" % val)
def _handle_frozendict(obj): def _handle_frozendict(obj: Any) -> Dict[Any, Any]:
"""Helper for json_encoder. Makes frozendicts serializable by returning """Helper for json_encoder. Makes frozendicts serializable by returning
the underlying dict the underlying dict
""" """
@ -60,10 +68,10 @@ json_encoder = json.JSONEncoder(
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json) json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
def unwrapFirstError(failure): def unwrapFirstError(failure: Failure) -> Failure:
# defer.gatherResults and DeferredLists wrap failures. # defer.gatherResults and DeferredLists wrap failures.
failure.trap(defer.FirstError) failure.trap(defer.FirstError)
return failure.value.subFailure return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations
@attr.s(slots=True) @attr.s(slots=True)
@ -75,25 +83,25 @@ class Clock:
reactor: The Twisted reactor to use. reactor: The Twisted reactor to use.
""" """
_reactor = attr.ib() _reactor: IReactorTime = attr.ib()
@defer.inlineCallbacks @defer.inlineCallbacks # type: ignore[arg-type] # Issue in Twisted's type annotations
def sleep(self, seconds): def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]":
d = defer.Deferred() d: defer.Deferred[float] = defer.Deferred()
with context.PreserveLoggingContext(): with context.PreserveLoggingContext():
self._reactor.callLater(seconds, d.callback, seconds) self._reactor.callLater(seconds, d.callback, seconds)
res = yield d res = yield d
return res return res
def time(self): def time(self) -> float:
"""Returns the current system time in seconds since epoch.""" """Returns the current system time in seconds since epoch."""
return self._reactor.seconds() return self._reactor.seconds()
def time_msec(self): def time_msec(self) -> int:
"""Returns the current system time in milliseconds since epoch.""" """Returns the current system time in milliseconds since epoch."""
return int(self.time() * 1000) return int(self.time() * 1000)
def looping_call(self, f, msec, *args, **kwargs): def looping_call(self, f: Callable, msec: float, *args, **kwargs) -> LoopingCall:
"""Call a function repeatedly. """Call a function repeatedly.
Waits `msec` initially before calling `f` for the first time. Waits `msec` initially before calling `f` for the first time.
@ -102,8 +110,8 @@ class Clock:
other than trivial, you probably want to wrap it in run_as_background_process. other than trivial, you probably want to wrap it in run_as_background_process.
Args: Args:
f(function): The function to call repeatedly. f: The function to call repeatedly.
msec(float): How long to wait between calls in milliseconds. msec: How long to wait between calls in milliseconds.
*args: Postional arguments to pass to function. *args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function. **kwargs: Key arguments to pass to function.
""" """
@ -113,7 +121,7 @@ class Clock:
d.addErrback(log_failure, "Looping call died", consumeErrors=False) d.addErrback(log_failure, "Looping call died", consumeErrors=False)
return call return call
def call_later(self, delay, callback, *args, **kwargs): def call_later(self, delay, callback, *args, **kwargs) -> IDelayedCall:
"""Call something later """Call something later
Note that the function will be called with no logcontext, so if it is anything Note that the function will be called with no logcontext, so if it is anything
@ -133,7 +141,7 @@ class Clock:
with context.PreserveLoggingContext(): with context.PreserveLoggingContext():
return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs) return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer, ignore_errs=False): def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None:
try: try:
timer.cancel() timer.cancel()
except Exception: except Exception:

View file

@ -37,6 +37,7 @@ import attr
from typing_extensions import ContextManager from typing_extensions import ContextManager
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.base import ReactorBase
from twisted.internet.defer import CancelledError from twisted.internet.defer import CancelledError
from twisted.internet.interfaces import IReactorTime from twisted.internet.interfaces import IReactorTime
from twisted.python import failure from twisted.python import failure
@ -268,6 +269,7 @@ class Linearizer:
if not clock: if not clock:
from twisted.internet import reactor from twisted.internet import reactor
assert isinstance(reactor, ReactorBase)
clock = Clock(reactor) clock = Clock(reactor)
self._clock = clock self._clock = clock
self.max_count = max_count self.max_count = max_count
@ -411,7 +413,7 @@ class ReadWriteLock:
# writers and readers have been resolved. The new writer replaces the latest # writers and readers have been resolved. The new writer replaces the latest
# writer. # writer.
def __init__(self): def __init__(self) -> None:
# Latest readers queued # Latest readers queued
self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {} self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {}
@ -503,7 +505,7 @@ def timeout_deferred(
timed_out = [False] timed_out = [False]
def time_it_out(): def time_it_out() -> None:
timed_out[0] = True timed_out[0] = True
try: try:
@ -550,19 +552,21 @@ def timeout_deferred(
return new_d return new_d
# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
class DoneAwaitable: class DoneAwaitable: # should be: Generic[R]
"""Simple awaitable that returns the provided value.""" """Simple awaitable that returns the provided value."""
value = attr.ib() value = attr.ib(type=Any) # should be: R
def __await__(self): def __await__(self):
return self return self
def __iter__(self): def __iter__(self) -> "DoneAwaitable":
return self return self
def __next__(self): def __next__(self) -> None:
raise StopIteration(self.value) raise StopIteration(self.value)

View file

@ -122,7 +122,7 @@ class BatchingQueue(Generic[V, R]):
# First we create a defer and add it and the value to the list of # First we create a defer and add it and the value to the list of
# pending items. # pending items.
d = defer.Deferred() d: defer.Deferred[R] = defer.Deferred()
self._next_values.setdefault(key, []).append((value, d)) self._next_values.setdefault(key, []).append((value, d))
# If we're not currently processing the key fire off a background # If we're not currently processing the key fire off a background

View file

@ -64,32 +64,32 @@ class CacheMetric:
evicted_size = attr.ib(default=0) evicted_size = attr.ib(default=0)
memory_usage = attr.ib(default=None) memory_usage = attr.ib(default=None)
def inc_hits(self): def inc_hits(self) -> None:
self.hits += 1 self.hits += 1
def inc_misses(self): def inc_misses(self) -> None:
self.misses += 1 self.misses += 1
def inc_evictions(self, size=1): def inc_evictions(self, size: int = 1) -> None:
self.evicted_size += size self.evicted_size += size
def inc_memory_usage(self, memory: int): def inc_memory_usage(self, memory: int) -> None:
if self.memory_usage is None: if self.memory_usage is None:
self.memory_usage = 0 self.memory_usage = 0
self.memory_usage += memory self.memory_usage += memory
def dec_memory_usage(self, memory: int): def dec_memory_usage(self, memory: int) -> None:
self.memory_usage -= memory self.memory_usage -= memory
def clear_memory_usage(self): def clear_memory_usage(self) -> None:
if self.memory_usage is not None: if self.memory_usage is not None:
self.memory_usage = 0 self.memory_usage = 0
def describe(self): def describe(self):
return [] return []
def collect(self): def collect(self) -> None:
try: try:
if self._cache_type == "response_cache": if self._cache_type == "response_cache":
response_cache_size.labels(self._cache_name).set(len(self._cache)) response_cache_size.labels(self._cache_name).set(len(self._cache))

View file

@ -93,7 +93,7 @@ class DeferredCache(Generic[KT, VT]):
TreeCache, "MutableMapping[KT, CacheEntry]" TreeCache, "MutableMapping[KT, CacheEntry]"
] = cache_type() ] = cache_type()
def metrics_cb(): def metrics_cb() -> None:
cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
# cache is used for completed results and maps to the result itself, rather than # cache is used for completed results and maps to the result itself, rather than
@ -113,7 +113,7 @@ class DeferredCache(Generic[KT, VT]):
def max_entries(self): def max_entries(self):
return self.cache.max_size return self.cache.max_size
def check_thread(self): def check_thread(self) -> None:
expected_thread = self.thread expected_thread = self.thread
if expected_thread is None: if expected_thread is None:
self.thread = threading.current_thread() self.thread = threading.current_thread()
@ -235,7 +235,7 @@ class DeferredCache(Generic[KT, VT]):
self._pending_deferred_cache[key] = entry self._pending_deferred_cache[key] = entry
def compare_and_pop(): def compare_and_pop() -> bool:
"""Check if our entry is still the one in _pending_deferred_cache, and """Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it. if so, pop it.
@ -256,7 +256,7 @@ class DeferredCache(Generic[KT, VT]):
return False return False
def cb(result): def cb(result) -> None:
if compare_and_pop(): if compare_and_pop():
self.cache.set(key, result, entry.callbacks) self.cache.set(key, result, entry.callbacks)
else: else:
@ -268,7 +268,7 @@ class DeferredCache(Generic[KT, VT]):
# not have been. Either way, let's double-check now. # not have been. Either way, let's double-check now.
entry.invalidate() entry.invalidate()
def eb(_fail): def eb(_fail) -> None:
compare_and_pop() compare_and_pop()
entry.invalidate() entry.invalidate()
@ -314,7 +314,7 @@ class DeferredCache(Generic[KT, VT]):
for entry in iterate_tree_cache_entry(entry): for entry in iterate_tree_cache_entry(entry):
entry.invalidate() entry.invalidate()
def invalidate_all(self): def invalidate_all(self) -> None:
self.check_thread() self.check_thread()
self.cache.clear() self.cache.clear()
for entry in self._pending_deferred_cache.values(): for entry in self._pending_deferred_cache.values():
@ -332,7 +332,7 @@ class CacheEntry:
self.callbacks = set(callbacks) self.callbacks = set(callbacks)
self.invalidated = False self.invalidated = False
def invalidate(self): def invalidate(self) -> None:
if not self.invalidated: if not self.invalidated:
self.invalidated = True self.invalidated = True
for callback in self.callbacks: for callback in self.callbacks:

View file

@ -27,10 +27,14 @@ logger = logging.getLogger(__name__)
KT = TypeVar("KT") KT = TypeVar("KT")
# The type of the dictionary keys. # The type of the dictionary keys.
DKT = TypeVar("DKT") DKT = TypeVar("DKT")
# The type of the dictionary values.
DV = TypeVar("DV")
# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
@attr.s(slots=True) @attr.s(slots=True)
class DictionaryEntry: class DictionaryEntry: # should be: Generic[DKT, DV].
"""Returned when getting an entry from the cache """Returned when getting an entry from the cache
Attributes: Attributes:
@ -43,10 +47,10 @@ class DictionaryEntry:
""" """
full = attr.ib(type=bool) full = attr.ib(type=bool)
known_absent = attr.ib() known_absent = attr.ib(type=Set[Any]) # should be: Set[DKT]
value = attr.ib() value = attr.ib(type=Dict[Any, Any]) # should be: Dict[DKT, DV]
def __len__(self): def __len__(self) -> int:
return len(self.value) return len(self.value)
@ -56,7 +60,7 @@ class _Sentinel(enum.Enum):
sentinel = object() sentinel = object()
class DictionaryCache(Generic[KT, DKT]): class DictionaryCache(Generic[KT, DKT, DV]):
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e. """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key. fetching a subset of dictionary keys for a particular key.
""" """
@ -87,7 +91,7 @@ class DictionaryCache(Generic[KT, DKT]):
Args: Args:
key key
dict_key: If given a set of keys then return only those keys dict_keys: If given a set of keys then return only those keys
that exist in the cache. that exist in the cache.
Returns: Returns:
@ -125,7 +129,7 @@ class DictionaryCache(Generic[KT, DKT]):
self, self,
sequence: int, sequence: int,
key: KT, key: KT,
value: Dict[DKT, Any], value: Dict[DKT, DV],
fetched_keys: Optional[Set[DKT]] = None, fetched_keys: Optional[Set[DKT]] = None,
) -> None: ) -> None:
"""Updates the entry in the cache """Updates the entry in the cache
@ -151,15 +155,15 @@ class DictionaryCache(Generic[KT, DKT]):
self._update_or_insert(key, value, fetched_keys) self._update_or_insert(key, value, fetched_keys)
def _update_or_insert( def _update_or_insert(
self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT] self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]
) -> None: ) -> None:
# We pop and reinsert as we need to tell the cache the size may have # We pop and reinsert as we need to tell the cache the size may have
# changed # changed
entry = self.cache.pop(key, DictionaryEntry(False, set(), {})) entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
entry.value.update(value) entry.value.update(value)
entry.known_absent.update(known_absent) entry.known_absent.update(known_absent)
self.cache[key] = entry self.cache[key] = entry
def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None: def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None:
self.cache[key] = DictionaryEntry(True, known_absent, value) self.cache[key] = DictionaryEntry(True, known_absent, value)

View file

@ -35,6 +35,7 @@ from typing import (
from typing_extensions import Literal from typing_extensions import Literal
from twisted.internet import reactor from twisted.internet import reactor
from twisted.internet.interfaces import IReactorTime
from synapse.config import cache as cache_config from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
@ -341,7 +342,7 @@ class LruCache(Generic[KT, VT]):
# Default `clock` to something sensible. Note that we rename it to # Default `clock` to something sensible. Note that we rename it to
# `real_clock` so that mypy doesn't think its still `Optional`. # `real_clock` so that mypy doesn't think its still `Optional`.
if clock is None: if clock is None:
real_clock = Clock(reactor) real_clock = Clock(cast(IReactorTime, reactor))
else: else:
real_clock = clock real_clock = clock
@ -384,7 +385,7 @@ class LruCache(Generic[KT, VT]):
lock = threading.Lock() lock = threading.Lock()
def evict(): def evict() -> None:
while cache_len() > self.max_size: while cache_len() > self.max_size:
# Get the last node in the list (i.e. the oldest node). # Get the last node in the list (i.e. the oldest node).
todelete = list_root.prev_node todelete = list_root.prev_node

View file

@ -195,7 +195,7 @@ class StreamChangeCache:
for entity in r: for entity in r:
del self._entity_to_key[entity] del self._entity_to_key[entity]
def _evict(self): def _evict(self) -> None:
while len(self._cache) > self._max_size: while len(self._cache) > self._max_size:
k, r = self._cache.popitem(0) k, r = self._cache.popitem(0)
self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)

View file

@ -35,17 +35,17 @@ class TreeCache:
root = {key_1: {key_2: _value}} root = {key_1: {key_2: _value}}
""" """
def __init__(self): def __init__(self) -> None:
self.size = 0 self.size: int = 0
self.root = TreeCacheNode() self.root = TreeCacheNode()
def __setitem__(self, key, value): def __setitem__(self, key, value) -> None:
return self.set(key, value) self.set(key, value)
def __contains__(self, key): def __contains__(self, key) -> bool:
return self.get(key, SENTINEL) is not SENTINEL return self.get(key, SENTINEL) is not SENTINEL
def set(self, key, value): def set(self, key, value) -> None:
if isinstance(value, TreeCacheNode): if isinstance(value, TreeCacheNode):
# this would mean we couldn't tell where our tree ended and the value # this would mean we couldn't tell where our tree ended and the value
# started. # started.
@ -73,7 +73,7 @@ class TreeCache:
return default return default
return node.get(key[-1], default) return node.get(key[-1], default)
def clear(self): def clear(self) -> None:
self.size = 0 self.size = 0
self.root = TreeCacheNode() self.root = TreeCacheNode()
@ -128,7 +128,7 @@ class TreeCache:
def values(self): def values(self):
return iterate_tree_cache_entry(self.root) return iterate_tree_cache_entry(self.root)
def __len__(self): def __len__(self) -> int:
return self.size return self.size

View file

@ -126,7 +126,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
signal.signal(signal.SIGTERM, sigterm) signal.signal(signal.SIGTERM, sigterm)
# Cleanup pid file at exit. # Cleanup pid file at exit.
def exit(): def exit() -> None:
logger.warning("Stopping daemon.") logger.warning("Stopping daemon.")
os.remove(pid_file) os.remove(pid_file)
sys.exit(0) sys.exit(0)

View file

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Callable, Dict, List
from twisted.internet import defer from twisted.internet import defer
@ -37,11 +38,11 @@ class Distributor:
model will do for today. model will do for today.
""" """
def __init__(self): def __init__(self) -> None:
self.signals = {} self.signals: Dict[str, Signal] = {}
self.pre_registration = {} self.pre_registration: Dict[str, List[Callable]] = {}
def declare(self, name): def declare(self, name: str) -> None:
if name in self.signals: if name in self.signals:
raise KeyError("%r already has a signal named %s" % (self, name)) raise KeyError("%r already has a signal named %s" % (self, name))
@ -52,7 +53,7 @@ class Distributor:
for observer in self.pre_registration[name]: for observer in self.pre_registration[name]:
signal.observe(observer) signal.observe(observer)
def observe(self, name, observer): def observe(self, name: str, observer: Callable) -> None:
if name in self.signals: if name in self.signals:
self.signals[name].observe(observer) self.signals[name].observe(observer)
else: else:
@ -62,7 +63,7 @@ class Distributor:
self.pre_registration[name] = [] self.pre_registration[name] = []
self.pre_registration[name].append(observer) self.pre_registration[name].append(observer)
def fire(self, name, *args, **kwargs): def fire(self, name: str, *args, **kwargs) -> None:
"""Dispatches the given signal to the registered observers. """Dispatches the given signal to the registered observers.
Runs the observers as a background process. Does not return a deferred. Runs the observers as a background process. Does not return a deferred.
@ -83,18 +84,18 @@ class Signal:
method into all of the observers. method into all of the observers.
""" """
def __init__(self, name): def __init__(self, name: str):
self.name = name self.name: str = name
self.observers = [] self.observers: List[Callable] = []
def observe(self, observer): def observe(self, observer: Callable) -> None:
"""Adds a new callable to the observer list which will be invoked by """Adds a new callable to the observer list which will be invoked by
the 'fire' method. the 'fire' method.
Each observer callable may return a Deferred.""" Each observer callable may return a Deferred."""
self.observers.append(observer) self.observers.append(observer)
def fire(self, *args, **kwargs): def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]":
"""Invokes every callable in the observer list, passing in the args and """Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is kwargs. Exceptions thrown by observers are logged but ignored. It is
not an error to fire a signal with no observers. not an error to fire a signal with no observers.

View file

@ -13,10 +13,14 @@
# limitations under the License. # limitations under the License.
import queue import queue
from typing import BinaryIO, Optional, Union, cast
from twisted.internet import threads from twisted.internet import threads
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IPullProducer, IPushProducer
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import ISynapseReactor
class BackgroundFileConsumer: class BackgroundFileConsumer:
@ -24,9 +28,9 @@ class BackgroundFileConsumer:
and pull producers and pull producers
Args: Args:
file_obj (file): The file like object to write to. Closed when file_obj: The file like object to write to. Closed when
finished. finished.
reactor (twisted.internet.reactor): the Twisted reactor to use reactor: the Twisted reactor to use
""" """
# For PushProducers pause if we have this many unwritten slices # For PushProducers pause if we have this many unwritten slices
@ -34,13 +38,13 @@ class BackgroundFileConsumer:
# And resume once the size of the queue is less than this # And resume once the size of the queue is less than this
_RESUME_ON_QUEUE_SIZE = 2 _RESUME_ON_QUEUE_SIZE = 2
def __init__(self, file_obj, reactor): def __init__(self, file_obj: BinaryIO, reactor: ISynapseReactor) -> None:
self._file_obj = file_obj self._file_obj: BinaryIO = file_obj
self._reactor = reactor self._reactor: ISynapseReactor = reactor
# Producer we're registered with # Producer we're registered with
self._producer = None self._producer: Optional[Union[IPushProducer, IPullProducer]] = None
# True if PushProducer, false if PullProducer # True if PushProducer, false if PullProducer
self.streaming = False self.streaming = False
@ -51,20 +55,22 @@ class BackgroundFileConsumer:
# Queue of slices of bytes to be written. When producer calls # Queue of slices of bytes to be written. When producer calls
# unregister a final None is sent. # unregister a final None is sent.
self._bytes_queue = queue.Queue() self._bytes_queue: queue.Queue[Optional[bytes]] = queue.Queue()
# Deferred that is resolved when finished writing # Deferred that is resolved when finished writing
self._finished_deferred = None self._finished_deferred: Optional[Deferred[None]] = None
# If the _writer thread throws an exception it gets stored here. # If the _writer thread throws an exception it gets stored here.
self._write_exception = None self._write_exception: Optional[Exception] = None
def registerProducer(self, producer, streaming): def registerProducer(
self, producer: Union[IPushProducer, IPullProducer], streaming: bool
) -> None:
"""Part of IConsumer interface """Part of IConsumer interface
Args: Args:
producer (IProducer) producer
streaming (bool): True if push based producer, False if pull streaming: True if push based producer, False if pull
based. based.
""" """
if self._producer: if self._producer:
@ -81,29 +87,33 @@ class BackgroundFileConsumer:
if not streaming: if not streaming:
self._producer.resumeProducing() self._producer.resumeProducing()
def unregisterProducer(self): def unregisterProducer(self) -> None:
"""Part of IProducer interface""" """Part of IProducer interface"""
self._producer = None self._producer = None
assert self._finished_deferred is not None
if not self._finished_deferred.called: if not self._finished_deferred.called:
self._bytes_queue.put_nowait(None) self._bytes_queue.put_nowait(None)
def write(self, bytes): def write(self, write_bytes: bytes) -> None:
"""Part of IProducer interface""" """Part of IProducer interface"""
if self._write_exception: if self._write_exception:
raise self._write_exception raise self._write_exception
assert self._finished_deferred is not None
if self._finished_deferred.called: if self._finished_deferred.called:
raise Exception("consumer has closed") raise Exception("consumer has closed")
self._bytes_queue.put_nowait(bytes) self._bytes_queue.put_nowait(write_bytes)
# If this is a PushProducer and the queue is getting behind # If this is a PushProducer and the queue is getting behind
# then we pause the producer. # then we pause the producer.
if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE: if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
self._paused_producer = True self._paused_producer = True
self._producer.pauseProducing() assert self._producer is not None
# cast safe because `streaming` means this is an IPushProducer
cast(IPushProducer, self._producer).pauseProducing()
def _writer(self): def _writer(self) -> None:
"""This is run in a background thread to write to the file.""" """This is run in a background thread to write to the file."""
try: try:
while self._producer or not self._bytes_queue.empty(): while self._producer or not self._bytes_queue.empty():
@ -130,11 +140,11 @@ class BackgroundFileConsumer:
finally: finally:
self._file_obj.close() self._file_obj.close()
def wait(self): def wait(self) -> "Deferred[None]":
"""Returns a deferred that resolves when finished writing to file""" """Returns a deferred that resolves when finished writing to file"""
return make_deferred_yieldable(self._finished_deferred) return make_deferred_yieldable(self._finished_deferred)
def _resume_paused_producer(self): def _resume_paused_producer(self) -> None:
"""Gets called if we should resume producing after being paused""" """Gets called if we should resume producing after being paused"""
if self._paused_producer and self._producer: if self._paused_producer and self._producer:
self._paused_producer = False self._paused_producer = False

View file

@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any
from frozendict import frozendict from frozendict import frozendict
def freeze(o): def freeze(o: Any) -> Any:
if isinstance(o, dict): if isinstance(o, dict):
return frozendict({k: freeze(v) for k, v in o.items()}) return frozendict({k: freeze(v) for k, v in o.items()})
@ -33,7 +34,7 @@ def freeze(o):
return o return o
def unfreeze(o): def unfreeze(o: Any) -> Any:
if isinstance(o, (dict, frozendict)): if isinstance(o, (dict, frozendict)):
return {k: unfreeze(v) for k, v in o.items()} return {k: unfreeze(v) for k, v in o.items()}

View file

@ -13,42 +13,43 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict
from twisted.web.resource import NoResource from twisted.web.resource import NoResource, Resource
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def create_resource_tree(desired_tree, root_resource): def create_resource_tree(
desired_tree: Dict[str, Resource], root_resource: Resource
) -> Resource:
"""Create the resource tree for this homeserver. """Create the resource tree for this homeserver.
This in unduly complicated because Twisted does not support putting This in unduly complicated because Twisted does not support putting
child resources more than 1 level deep at a time. child resources more than 1 level deep at a time.
Args: Args:
web_client (bool): True to enable the web client. desired_tree: Dict from desired paths to desired resources.
root_resource (twisted.web.resource.Resource): The root root_resource: The root resource to add the tree to.
resource to add the tree to.
Returns: Returns:
twisted.web.resource.Resource: the ``root_resource`` with a tree of The ``root_resource`` with a tree of child resources added to it.
child resources added to it.
""" """
# ideally we'd just use getChild and putChild but getChild doesn't work # ideally we'd just use getChild and putChild but getChild doesn't work
# unless you give it a Request object IN ADDITION to the name :/ So # unless you give it a Request object IN ADDITION to the name :/ So
# instead, we'll store a copy of this mapping so we can actually add # instead, we'll store a copy of this mapping so we can actually add
# extra resources to existing nodes. See self._resource_id for the key. # extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {} resource_mappings: Dict[str, Resource] = {}
for full_path, res in desired_tree.items(): for full_path_str, res in desired_tree.items():
# twisted requires all resources to be bytes # twisted requires all resources to be bytes
full_path = full_path.encode("utf-8") full_path = full_path_str.encode("utf-8")
logger.info("Attaching %s to path %s", res, full_path) logger.info("Attaching %s to path %s", res, full_path)
last_resource = root_resource last_resource = root_resource
for path_seg in full_path.split(b"/")[1:-1]: for path_seg in full_path.split(b"/")[1:-1]:
if path_seg not in last_resource.listNames(): if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource" # resource doesn't exist, so make a "dummy resource"
child_resource = NoResource() child_resource: Resource = NoResource()
last_resource.putChild(path_seg, child_resource) last_resource.putChild(path_seg, child_resource)
res_id = _resource_id(last_resource, path_seg) res_id = _resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource resource_mappings[res_id] = child_resource
@ -83,7 +84,7 @@ def create_resource_tree(desired_tree, root_resource):
return root_resource return root_resource
def _resource_id(resource, path_seg): def _resource_id(resource: Resource, path_seg: bytes) -> str:
"""Construct an arbitrary resource ID so you can retrieve the mapping """Construct an arbitrary resource ID so you can retrieve the mapping
later. later.
@ -96,4 +97,4 @@ def _resource_id(resource, path_seg):
Returns: Returns:
str: A unique string which can be a key to the child Resource. str: A unique string which can be a key to the child Resource.
""" """
return "%s-%s" % (resource, path_seg) return "%s-%r" % (resource, path_seg)

View file

@ -74,7 +74,7 @@ class ListNode(Generic[P]):
new_node._refs_insert_after(node) new_node._refs_insert_after(node)
return new_node return new_node
def remove_from_list(self): def remove_from_list(self) -> None:
"""Remove this node from the list.""" """Remove this node from the list."""
with self._LOCK: with self._LOCK:
self._refs_remove_node_from_list() self._refs_remove_node_from_list()
@ -84,7 +84,7 @@ class ListNode(Generic[P]):
# immediately rather than at the next GC. # immediately rather than at the next GC.
self.cache_entry = None self.cache_entry = None
def move_after(self, node: "ListNode"): def move_after(self, node: "ListNode") -> None:
"""Move this node from its current location in the list to after the """Move this node from its current location in the list to after the
given node. given node.
""" """
@ -103,7 +103,7 @@ class ListNode(Generic[P]):
# Insert self back into the list, after target node # Insert self back into the list, after target node
self._refs_insert_after(node) self._refs_insert_after(node)
def _refs_remove_node_from_list(self): def _refs_remove_node_from_list(self) -> None:
"""Internal method to *just* remove the node from the list, without """Internal method to *just* remove the node from the list, without
e.g. clearing out the cache entry. e.g. clearing out the cache entry.
""" """
@ -122,7 +122,7 @@ class ListNode(Generic[P]):
self.prev_node = None self.prev_node = None
self.next_node = None self.next_node = None
def _refs_insert_after(self, node: "ListNode"): def _refs_insert_after(self, node: "ListNode") -> None:
"""Internal method to insert the node after the given node.""" """Internal method to insert the node after the given node."""
# This method should only be called when we're not already in the list. # This method should only be called when we're not already in the list.

View file

@ -77,7 +77,7 @@ def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> N
should be considered expired. Normally the current time. should be considered expired. Normally the current time.
""" """
def verify_expiry_caveat(caveat: str): def verify_expiry_caveat(caveat: str) -> bool:
time_msec = get_time_ms() time_msec = get_time_ms()
prefix = "time < " prefix = "time < "
if not caveat.startswith(prefix): if not caveat.startswith(prefix):

View file

@ -15,6 +15,7 @@
import inspect import inspect
import sys import sys
import traceback import traceback
from typing import Any, Dict, Optional
from twisted.conch import manhole_ssh from twisted.conch import manhole_ssh
from twisted.conch.insults import insults from twisted.conch.insults import insults
@ -22,6 +23,9 @@ from twisted.conch.manhole import ColoredManhole, ManholeInterpreter
from twisted.conch.ssh.keys import Key from twisted.conch.ssh.keys import Key
from twisted.cred import checkers, portal from twisted.cred import checkers, portal
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.protocol import Factory
from synapse.config.server import ManholeConfig
PUBLIC_KEY = ( PUBLIC_KEY = (
"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5" "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5"
@ -61,22 +65,22 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs=
-----END RSA PRIVATE KEY-----""" -----END RSA PRIVATE KEY-----"""
def manhole(settings, globals): def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory:
"""Starts a ssh listener with password authentication using """Starts a ssh listener with password authentication using
the given username and password. Clients connecting to the ssh the given username and password. Clients connecting to the ssh
listener will find themselves in a colored python shell with listener will find themselves in a colored python shell with
the supplied globals. the supplied globals.
Args: Args:
username(str): The username ssh clients should auth with. username: The username ssh clients should auth with.
password(str): The password ssh clients should auth with. password: The password ssh clients should auth with.
globals(dict): The variables to expose in the shell. globals: The variables to expose in the shell.
Returns: Returns:
twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` A factory to pass to ``listenTCP``
""" """
username = settings.username username = settings.username
password = settings.password password = settings.password.encode("ascii")
priv_key = settings.priv_key priv_key = settings.priv_key
if priv_key is None: if priv_key is None:
priv_key = Key.fromString(PRIVATE_KEY) priv_key = Key.fromString(PRIVATE_KEY)
@ -84,19 +88,22 @@ def manhole(settings, globals):
if pub_key is None: if pub_key is None:
pub_key = Key.fromString(PUBLIC_KEY) pub_key = Key.fromString(PUBLIC_KEY)
if not isinstance(password, bytes):
password = password.encode("ascii")
checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password}) checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password})
rlm = manhole_ssh.TerminalRealm() rlm = manhole_ssh.TerminalRealm()
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( # mypy ignored here because:
# - can't deduce types of lambdas
# - variable is Type[ServerProtocol], expr is Callable[[], ServerProtocol]
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( # type: ignore[misc,assignment]
SynapseManhole, dict(globals, __name__="__console__") SynapseManhole, dict(globals, __name__="__console__")
) )
factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
factory.privateKeys[b"ssh-rsa"] = priv_key
factory.publicKeys[b"ssh-rsa"] = pub_key # conch has the wrong type on these dicts (says bytes to bytes,
# should be bytes to Keys judging by how it's used).
factory.privateKeys[b"ssh-rsa"] = priv_key # type: ignore[assignment]
factory.publicKeys[b"ssh-rsa"] = pub_key # type: ignore[assignment]
return factory return factory
@ -104,7 +111,7 @@ def manhole(settings, globals):
class SynapseManhole(ColoredManhole): class SynapseManhole(ColoredManhole):
"""Overrides connectionMade to create our own ManholeInterpreter""" """Overrides connectionMade to create our own ManholeInterpreter"""
def connectionMade(self): def connectionMade(self) -> None:
super().connectionMade() super().connectionMade()
# replace the manhole interpreter with our own impl # replace the manhole interpreter with our own impl
@ -114,13 +121,14 @@ class SynapseManhole(ColoredManhole):
class SynapseManholeInterpreter(ManholeInterpreter): class SynapseManholeInterpreter(ManholeInterpreter):
def showsyntaxerror(self, filename=None): def showsyntaxerror(self, filename: Optional[str] = None) -> None:
"""Display the syntax error that just occurred. """Display the syntax error that just occurred.
Overrides the base implementation, ignoring sys.excepthook. We always want Overrides the base implementation, ignoring sys.excepthook. We always want
any syntax errors to be sent to the terminal, rather than sentry. any syntax errors to be sent to the terminal, rather than sentry.
""" """
type, value, tb = sys.exc_info() type, value, tb = sys.exc_info()
assert value is not None
sys.last_type = type sys.last_type = type
sys.last_value = value sys.last_value = value
sys.last_traceback = tb sys.last_traceback = tb
@ -138,7 +146,7 @@ class SynapseManholeInterpreter(ManholeInterpreter):
lines = traceback.format_exception_only(type, value) lines = traceback.format_exception_only(type, value)
self.write("".join(lines)) self.write("".join(lines))
def showtraceback(self): def showtraceback(self) -> None:
"""Display the exception that just occurred. """Display the exception that just occurred.
Overrides the base implementation, ignoring sys.excepthook. We always want Overrides the base implementation, ignoring sys.excepthook. We always want
@ -146,14 +154,22 @@ class SynapseManholeInterpreter(ManholeInterpreter):
""" """
sys.last_type, sys.last_value, last_tb = ei = sys.exc_info() sys.last_type, sys.last_value, last_tb = ei = sys.exc_info()
sys.last_traceback = last_tb sys.last_traceback = last_tb
assert last_tb is not None
try: try:
# We remove the first stack item because it is our own code. # We remove the first stack item because it is our own code.
lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next) lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next)
self.write("".join(lines)) self.write("".join(lines))
finally: finally:
last_tb = ei = None # On the line below, last_tb and ei appear to be dead.
# It's unclear whether there is a reason behind this line.
# It conceivably could be because an exception raised in this block
# will keep the local frame (containing these local variables) around.
# This was adapted taken from CPython's Lib/code.py; see here:
# https://github.com/python/cpython/blob/4dc4300c686f543d504ab6fa9fe600eaf11bb695/Lib/code.py#L131-L150
last_tb = ei = None # type: ignore
def displayhook(self, obj): def displayhook(self, obj: Any) -> None:
""" """
We override the displayhook so that we automatically convert coroutines We override the displayhook so that we automatically convert coroutines
into Deferreds. (Our superclass' displayhook will take care of the rest, into Deferreds. (Our superclass' displayhook will take care of the rest,

View file

@ -24,7 +24,7 @@ from twisted.python.failure import Failure
_already_patched = False _already_patched = False
def do_patch(): def do_patch() -> None:
""" """
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
""" """
@ -107,7 +107,7 @@ def do_patch():
_already_patched = True _already_patched = True
def _check_yield_points(f: Callable, changes: List[str]): def _check_yield_points(f: Callable, changes: List[str]) -> Callable:
"""Wraps a generator that is about to be passed to defer.inlineCallbacks """Wraps a generator that is about to be passed to defer.inlineCallbacks
checking that after every yield the log contexts are correct. checking that after every yield the log contexts are correct.

View file

@ -15,33 +15,36 @@
import collections import collections
import contextlib import contextlib
import logging import logging
import typing
from typing import Any, DefaultDict, Iterator, List, Set
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import LimitExceededError from synapse.api.errors import LimitExceededError
from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.logging.context import ( from synapse.logging.context import (
PreserveLoggingContext, PreserveLoggingContext,
make_deferred_yieldable, make_deferred_yieldable,
run_in_background, run_in_background,
) )
from synapse.util import Clock
if typing.TYPE_CHECKING:
from contextlib import _GeneratorContextManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FederationRateLimiter: class FederationRateLimiter:
def __init__(self, clock, config): def __init__(self, clock: Clock, config: FederationRateLimitConfig):
""" def new_limiter() -> "_PerHostRatelimiter":
Args:
clock (Clock)
config (FederationRateLimitConfig)
"""
def new_limiter():
return _PerHostRatelimiter(clock=clock, config=config) return _PerHostRatelimiter(clock=clock, config=config)
self.ratelimiters = collections.defaultdict(new_limiter) self.ratelimiters: DefaultDict[
str, "_PerHostRatelimiter"
] = collections.defaultdict(new_limiter)
def ratelimit(self, host): def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
"""Used to ratelimit an incoming request from a given host """Used to ratelimit an incoming request from a given host
Example usage: Example usage:
@ -60,11 +63,11 @@ class FederationRateLimiter:
class _PerHostRatelimiter: class _PerHostRatelimiter:
def __init__(self, clock, config): def __init__(self, clock: Clock, config: FederationRateLimitConfig):
""" """
Args: Args:
clock (Clock) clock
config (FederationRateLimitConfig) config
""" """
self.clock = clock self.clock = clock
@ -75,21 +78,23 @@ class _PerHostRatelimiter:
self.concurrent_requests = config.concurrent self.concurrent_requests = config.concurrent
# request_id objects for requests which have been slept # request_id objects for requests which have been slept
self.sleeping_requests = set() self.sleeping_requests: Set[object] = set()
# map from request_id object to Deferred for requests which are ready # map from request_id object to Deferred for requests which are ready
# for processing but have been queued # for processing but have been queued
self.ready_request_queue = collections.OrderedDict() self.ready_request_queue: collections.OrderedDict[
object, defer.Deferred[None]
] = collections.OrderedDict()
# request id objects for requests which are in progress # request id objects for requests which are in progress
self.current_processing = set() self.current_processing: Set[object] = set()
# times at which we have recently (within the last window_size ms) # times at which we have recently (within the last window_size ms)
# received requests. # received requests.
self.request_times = [] self.request_times: List[int] = []
@contextlib.contextmanager @contextlib.contextmanager
def ratelimit(self): def ratelimit(self) -> "Iterator[defer.Deferred[None]]":
# `contextlib.contextmanager` takes a generator and turns it into a # `contextlib.contextmanager` takes a generator and turns it into a
# context manager. The generator should only yield once with a value # context manager. The generator should only yield once with a value
# to be returned by manager. # to be returned by manager.
@ -102,7 +107,7 @@ class _PerHostRatelimiter:
finally: finally:
self._on_exit(request_id) self._on_exit(request_id)
def _on_enter(self, request_id): def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
# remove any entries from request_times which aren't within the window # remove any entries from request_times which aren't within the window
@ -120,9 +125,9 @@ class _PerHostRatelimiter:
self.request_times.append(time_now) self.request_times.append(time_now)
def queue_request(): def queue_request() -> "defer.Deferred[None]":
if len(self.current_processing) >= self.concurrent_requests: if len(self.current_processing) >= self.concurrent_requests:
queue_defer = defer.Deferred() queue_defer: defer.Deferred[None] = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer self.ready_request_queue[request_id] = queue_defer
logger.info( logger.info(
"Ratelimiter: queueing request (queue now %i items)", "Ratelimiter: queueing request (queue now %i items)",
@ -145,7 +150,7 @@ class _PerHostRatelimiter:
self.sleeping_requests.add(request_id) self.sleeping_requests.add(request_id)
def on_wait_finished(_): def on_wait_finished(_: Any) -> "defer.Deferred[None]":
logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id)) logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
self.sleeping_requests.discard(request_id) self.sleeping_requests.discard(request_id)
queue_defer = queue_request() queue_defer = queue_request()
@ -155,19 +160,19 @@ class _PerHostRatelimiter:
else: else:
ret_defer = queue_request() ret_defer = queue_request()
def on_start(r): def on_start(r: object) -> object:
logger.debug("Ratelimit [%s]: Processing req", id(request_id)) logger.debug("Ratelimit [%s]: Processing req", id(request_id))
self.current_processing.add(request_id) self.current_processing.add(request_id)
return r return r
def on_err(r): def on_err(r: object) -> object:
# XXX: why is this necessary? this is called before we start # XXX: why is this necessary? this is called before we start
# processing the request so why would the request be in # processing the request so why would the request be in
# current_processing? # current_processing?
self.current_processing.discard(request_id) self.current_processing.discard(request_id)
return r return r
def on_both(r): def on_both(r: object) -> object:
# Ensure that we've properly cleaned up. # Ensure that we've properly cleaned up.
self.sleeping_requests.discard(request_id) self.sleeping_requests.discard(request_id)
self.ready_request_queue.pop(request_id, None) self.ready_request_queue.pop(request_id, None)
@ -177,7 +182,7 @@ class _PerHostRatelimiter:
ret_defer.addBoth(on_both) ret_defer.addBoth(on_both)
return make_deferred_yieldable(ret_defer) return make_deferred_yieldable(ret_defer)
def _on_exit(self, request_id): def _on_exit(self, request_id: object) -> None:
logger.debug("Ratelimit [%s]: Processed req", id(request_id)) logger.debug("Ratelimit [%s]: Processed req", id(request_id))
self.current_processing.discard(request_id) self.current_processing.discard(request_id)
try: try:

View file

@ -13,9 +13,13 @@
# limitations under the License. # limitations under the License.
import logging import logging
import random import random
from types import TracebackType
from typing import Any, Optional, Type
import synapse.logging.context import synapse.logging.context
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.storage import DataStore
from synapse.util import Clock
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,17 +34,17 @@ MAX_RETRY_INTERVAL = 2 ** 62
class NotRetryingDestination(Exception): class NotRetryingDestination(Exception):
def __init__(self, retry_last_ts, retry_interval, destination): def __init__(self, retry_last_ts: int, retry_interval: int, destination: str):
"""Raised by the limiter (and federation client) to indicate that we are """Raised by the limiter (and federation client) to indicate that we are
are deliberately not attempting to contact a given server. are deliberately not attempting to contact a given server.
Args: Args:
retry_last_ts (int): the unix ts in milliseconds of our last attempt retry_last_ts: the unix ts in milliseconds of our last attempt
to contact the server. 0 indicates that the last attempt was to contact the server. 0 indicates that the last attempt was
successful or that we've never actually attempted to connect. successful or that we've never actually attempted to connect.
retry_interval (int): the time in milliseconds to wait until the next retry_interval: the time in milliseconds to wait until the next
attempt. attempt.
destination (str): the domain in question destination: the domain in question
""" """
msg = "Not retrying server %s." % (destination,) msg = "Not retrying server %s." % (destination,)
@ -51,7 +55,13 @@ class NotRetryingDestination(Exception):
self.destination = destination self.destination = destination
async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs): async def get_retry_limiter(
destination: str,
clock: Clock,
store: DataStore,
ignore_backoff: bool = False,
**kwargs: Any,
) -> "RetryDestinationLimiter":
"""For a given destination check if we have previously failed to """For a given destination check if we have previously failed to
send a request there and are waiting before retrying the destination. send a request there and are waiting before retrying the destination.
If we are not ready to retry the destination, this will raise a If we are not ready to retry the destination, this will raise a
@ -60,10 +70,10 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
CodeMessageException with code < 500) CodeMessageException with code < 500)
Args: Args:
destination (str): name of homeserver destination: name of homeserver
clock (synapse.util.clock): timing source clock: timing source
store (synapse.storage.transactions.TransactionStore): datastore store: datastore
ignore_backoff (bool): true to ignore the historical backoff data and ignore_backoff: true to ignore the historical backoff data and
try the request anyway. We will still reset the retry_interval on success. try the request anyway. We will still reset the retry_interval on success.
Example usage: Example usage:
@ -114,13 +124,13 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
class RetryDestinationLimiter: class RetryDestinationLimiter:
def __init__( def __init__(
self, self,
destination, destination: str,
clock, clock: Clock,
store, store: DataStore,
failure_ts, failure_ts: Optional[int],
retry_interval, retry_interval: int,
backoff_on_404=False, backoff_on_404: bool = False,
backoff_on_failure=True, backoff_on_failure: bool = True,
): ):
"""Marks the destination as "down" if an exception is thrown in the """Marks the destination as "down" if an exception is thrown in the
context, except for CodeMessageException with code < 500. context, except for CodeMessageException with code < 500.
@ -128,17 +138,17 @@ class RetryDestinationLimiter:
If no exception is raised, marks the destination as "up". If no exception is raised, marks the destination as "up".
Args: Args:
destination (str) destination
clock (Clock) clock
store (DataStore) store
failure_ts (int|None): when this destination started failing (in ms since failure_ts: when this destination started failing (in ms since
the epoch), or zero if the last request was successful the epoch), or zero if the last request was successful
retry_interval (int): The next retry interval taken from the retry_interval: The next retry interval taken from the
database in milliseconds, or zero if the last request was database in milliseconds, or zero if the last request was
successful. successful.
backoff_on_404 (bool): Back off if we get a 404 backoff_on_404: Back off if we get a 404
backoff_on_failure (bool): set to False if we should not increase the backoff_on_failure: set to False if we should not increase the
retry interval on a failure. retry interval on a failure.
""" """
self.clock = clock self.clock = clock
@ -150,10 +160,15 @@ class RetryDestinationLimiter:
self.backoff_on_404 = backoff_on_404 self.backoff_on_404 = backoff_on_404
self.backoff_on_failure = backoff_on_failure self.backoff_on_failure = backoff_on_failure
def __enter__(self): def __enter__(self) -> None:
pass pass
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
valid_err_code = False valid_err_code = False
if exc_type is None: if exc_type is None:
valid_err_code = True valid_err_code = True
@ -161,7 +176,7 @@ class RetryDestinationLimiter:
# avoid treating exceptions which don't derive from Exception as # avoid treating exceptions which don't derive from Exception as
# failures; this is mostly so as not to catch defer._DefGen. # failures; this is mostly so as not to catch defer._DefGen.
valid_err_code = True valid_err_code = True
elif issubclass(exc_type, CodeMessageException): elif isinstance(exc_val, CodeMessageException):
# Some error codes are perfectly fine for some APIs, whereas other # Some error codes are perfectly fine for some APIs, whereas other
# APIs may expect to never received e.g. a 404. It's important to # APIs may expect to never received e.g. a 404. It's important to
# handle 404 as some remote servers will return a 404 when the HS # handle 404 as some remote servers will return a 404 when the HS
@ -216,7 +231,7 @@ class RetryDestinationLimiter:
if self.failure_ts is None: if self.failure_ts is None:
self.failure_ts = retry_last_ts self.failure_ts = retry_last_ts
async def store_retry_timings(): async def store_retry_timings() -> None:
try: try:
await self.store.set_destination_retry_timings( await self.store.set_destination_retry_timings(
self.destination, self.destination,

View file

@ -18,7 +18,7 @@ import resource
logger = logging.getLogger("synapse.app.homeserver") logger = logging.getLogger("synapse.app.homeserver")
def change_resource_limit(soft_file_no): def change_resource_limit(soft_file_no: int) -> None:
try: try:
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)

View file

@ -16,7 +16,7 @@
import time import time
import urllib.parse import urllib.parse
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union
import jinja2 import jinja2
@ -25,9 +25,9 @@ if TYPE_CHECKING:
def build_jinja_env( def build_jinja_env(
template_search_directories: Iterable[str], template_search_directories: Sequence[str],
config: "HomeServerConfig", config: "HomeServerConfig",
autoescape: Union[bool, Callable[[str], bool], None] = None, autoescape: Union[bool, Callable[[Optional[str]], bool], None] = None,
) -> jinja2.Environment: ) -> jinja2.Environment:
"""Set up a Jinja2 environment to load templates from the given search path """Set up a Jinja2 environment to load templates from the given search path
@ -110,5 +110,5 @@ def _create_mxc_to_http_filter(
return mxc_to_http_filter return mxc_to_http_filter
def _format_ts_filter(value: int, format: str): def _format_ts_filter(value: int, format: str) -> str:
return time.strftime(format, time.localtime(value / 1000)) return time.strftime(format, time.localtime(value / 1000))

View file

@ -14,6 +14,10 @@
import logging import logging
import re import re
import typing
if typing.TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -28,13 +32,13 @@ logger = logging.getLogger(__name__)
MAX_EMAIL_ADDRESS_LENGTH = 500 MAX_EMAIL_ADDRESS_LENGTH = 500
def check_3pid_allowed(hs, medium, address): def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool:
"""Checks whether a given format of 3PID is allowed to be used on this HS """Checks whether a given format of 3PID is allowed to be used on this HS
Args: Args:
hs (synapse.server.HomeServer): server hs: server
medium (str): 3pid medium - e.g. email, msisdn medium: 3pid medium - e.g. email, msisdn
address (str): address within that medium (e.g. "wotan@matrix.org") address: address within that medium (e.g. "wotan@matrix.org")
msisdns need to first have been canonicalised msisdns need to first have been canonicalised
Returns: Returns:
bool: whether the 3PID medium/address is allowed to be added to this HS bool: whether the 3PID medium/address is allowed to be added to this HS

View file

@ -19,7 +19,7 @@ import subprocess
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_version_string(module): def get_version_string(module) -> str:
"""Given a module calculate a git-aware version string for it. """Given a module calculate a git-aware version string for it.
If called on a module not in a git checkout will return `__verison__`. If called on a module not in a git checkout will return `__verison__`.

View file

@ -11,38 +11,41 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Generic, List, TypeVar
T = TypeVar("T")
class _Entry: class _Entry(Generic[T]):
__slots__ = ["end_key", "queue"] __slots__ = ["end_key", "queue"]
def __init__(self, end_key): def __init__(self, end_key: int) -> None:
self.end_key = end_key self.end_key: int = end_key
self.queue = [] self.queue: List[T] = []
class WheelTimer: class WheelTimer(Generic[T]):
"""Stores arbitrary objects that will be returned after their timers have """Stores arbitrary objects that will be returned after their timers have
expired. expired.
""" """
def __init__(self, bucket_size=5000): def __init__(self, bucket_size: int = 5000) -> None:
""" """
Args: Args:
bucket_size (int): Size of buckets in ms. Corresponds roughly to the bucket_size: Size of buckets in ms. Corresponds roughly to the
accuracy of the timer. accuracy of the timer.
""" """
self.bucket_size = bucket_size self.bucket_size: int = bucket_size
self.entries = [] self.entries: List[_Entry[T]] = []
self.current_tick = 0 self.current_tick: int = 0
def insert(self, now, obj, then): def insert(self, now: int, obj: T, then: int) -> None:
"""Inserts object into timer. """Inserts object into timer.
Args: Args:
now (int): Current time in msec now: Current time in msec
obj (object): Object to be inserted obj: Object to be inserted
then (int): When to return the object strictly after. then: When to return the object strictly after.
""" """
then_key = int(then / self.bucket_size) + 1 then_key = int(then / self.bucket_size) + 1
@ -70,7 +73,7 @@ class WheelTimer:
self.entries[-1].queue.append(obj) self.entries[-1].queue.append(obj)
def fetch(self, now): def fetch(self, now: int) -> List[T]:
"""Fetch any objects that have timed out """Fetch any objects that have timed out
Args: Args:
@ -87,5 +90,5 @@ class WheelTimer:
return ret return ret
def __len__(self): def __len__(self) -> int:
return sum(len(entry.queue) for entry in self.entries) return sum(len(entry.queue) for entry in self.entries)

View file

@ -734,9 +734,9 @@ class TestTransportLayerServer(JsonResource):
FederationRateLimitConfig( FederationRateLimitConfig(
window_size=1, window_size=1,
sleep_limit=1, sleep_limit=1,
sleep_msec=1, sleep_delay=1,
reject_limit=1000, reject_limit=1000,
concurrent_requests=1000, concurrent=1000,
), ),
) )