Add types to synapse.util. (#10601)

This commit is contained in:
reivilibre 2021-09-10 17:03:18 +01:00 committed by GitHub
parent ceab5a4bfa
commit 524b8ead77
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
41 changed files with 400 additions and 253 deletions

1
changelog.d/10601.misc Normal file
View file

@ -0,0 +1 @@
Add type annotations to the synapse.util package.

View file

@ -74,17 +74,7 @@ files =
synapse/storage/util,
synapse/streams,
synapse/types.py,
synapse/util/async_helpers.py,
synapse/util/caches,
synapse/util/daemonize.py,
synapse/util/hash.py,
synapse/util/iterutils.py,
synapse/util/linked_list.py,
synapse/util/metrics.py,
synapse/util/macaroons.py,
synapse/util/module_loader.py,
synapse/util/msisdn.py,
synapse/util/stringutils.py,
synapse/util,
synapse/visibility.py,
tests/replication,
tests/test_event_auth.py,
@ -102,6 +92,69 @@ files =
[mypy-synapse.rest.client.*]
disallow_untyped_defs = True
[mypy-synapse.util.batching_queue]
disallow_untyped_defs = True
[mypy-synapse.util.caches.dictionary_cache]
disallow_untyped_defs = True
[mypy-synapse.util.file_consumer]
disallow_untyped_defs = True
[mypy-synapse.util.frozenutils]
disallow_untyped_defs = True
[mypy-synapse.util.hash]
disallow_untyped_defs = True
[mypy-synapse.util.httpresourcetree]
disallow_untyped_defs = True
[mypy-synapse.util.iterutils]
disallow_untyped_defs = True
[mypy-synapse.util.linked_list]
disallow_untyped_defs = True
[mypy-synapse.util.logcontext]
disallow_untyped_defs = True
[mypy-synapse.util.logformatter]
disallow_untyped_defs = True
[mypy-synapse.util.macaroons]
disallow_untyped_defs = True
[mypy-synapse.util.manhole]
disallow_untyped_defs = True
[mypy-synapse.util.module_loader]
disallow_untyped_defs = True
[mypy-synapse.util.msisdn]
disallow_untyped_defs = True
[mypy-synapse.util.ratelimitutils]
disallow_untyped_defs = True
[mypy-synapse.util.retryutils]
disallow_untyped_defs = True
[mypy-synapse.util.rlimit]
disallow_untyped_defs = True
[mypy-synapse.util.stringutils]
disallow_untyped_defs = True
[mypy-synapse.util.templates]
disallow_untyped_defs = True
[mypy-synapse.util.threepids]
disallow_untyped_defs = True
[mypy-synapse.util.wheel_timer]
disallow_untyped_defs = True
[mypy-pymacaroons.*]
ignore_missing_imports = True

View file

@ -73,4 +73,4 @@ class RedisFactory(protocol.ReconnectingClientFactory):
def buildProtocol(self, addr) -> RedisProtocol: ...
class SubscriberFactory(RedisFactory):
def __init__(self): ...
def __init__(self) -> None: ...

View file

@ -46,7 +46,7 @@ class Ratelimiter:
# * How many times an action has occurred since a point in time
# * The point in time
# * The rate_hz of this particular entry. This can vary per request
self.actions: OrderedDict[Hashable, Tuple[float, int, float]] = OrderedDict()
self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
async def can_do_action(
self,
@ -56,7 +56,7 @@ class Ratelimiter:
burst_count: Optional[int] = None,
update: bool = True,
n_actions: int = 1,
_time_now_s: Optional[int] = None,
_time_now_s: Optional[float] = None,
) -> Tuple[bool, float]:
"""Can the entity (e.g. user or IP address) perform the action?
@ -160,7 +160,7 @@ class Ratelimiter:
return allowed, time_allowed
def _prune_message_counts(self, time_now_s: int):
def _prune_message_counts(self, time_now_s: float):
"""Remove message count entries that have not exceeded their defined
rate_hz limit
@ -188,7 +188,7 @@ class Ratelimiter:
burst_count: Optional[int] = None,
update: bool = True,
n_actions: int = 1,
_time_now_s: Optional[int] = None,
_time_now_s: Optional[float] = None,
):
"""Checks if an action can be performed. If not, raises a LimitExceededError

View file

@ -14,6 +14,8 @@
from typing import Dict, Optional
import attr
from ._base import Config
@ -29,18 +31,13 @@ class RateLimitConfig:
self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
@attr.s(auto_attribs=True)
class FederationRateLimitConfig:
_items_and_default = {
"window_size": 1000,
"sleep_limit": 10,
"sleep_delay": 500,
"reject_limit": 50,
"concurrent": 3,
}
def __init__(self, **kwargs):
for i in self._items_and_default.keys():
setattr(self, i, kwargs.get(i) or self._items_and_default[i])
window_size: int = 1000
sleep_limit: int = 10
sleep_delay: int = 500
reject_limit: int = 50
concurrent: int = 3
class RatelimitConfig(Config):
@ -69,11 +66,15 @@ class RatelimitConfig(Config):
else:
self.rc_federation = FederationRateLimitConfig(
**{
k: v
for k, v in {
"window_size": config.get("federation_rc_window_size"),
"sleep_limit": config.get("federation_rc_sleep_limit"),
"sleep_delay": config.get("federation_rc_sleep_delay"),
"reject_limit": config.get("federation_rc_reject_limit"),
"concurrent": config.get("federation_rc_concurrent"),
}.items()
if v is not None
}
)

View file

@ -22,6 +22,7 @@ from prometheus_client import Counter
from typing_extensions import Literal
from twisted.internet import defer
from twisted.internet.interfaces import IDelayedCall
import synapse.metrics
from synapse.api.presence import UserPresenceState
@ -284,7 +285,9 @@ class FederationSender(AbstractFederationSender):
)
# wake up destinations that have outstanding PDUs to be caught up
self._catchup_after_startup_timer = self.clock.call_later(
self._catchup_after_startup_timer: Optional[
IDelayedCall
] = self.clock.call_later(
CATCH_UP_STARTUP_DELAY_SEC,
run_as_background_process,
"wake_destinations_needing_catchup",
@ -406,7 +409,7 @@ class FederationSender(AbstractFederationSender):
now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)
assert ts is not None
synapse.metrics.event_processing_lag_by_event.labels(
"federation_sender"
).observe((now - ts) / 1000)
@ -435,6 +438,7 @@ class FederationSender(AbstractFederationSender):
if events:
now = self.clock.time_msec()
ts = await self.store.get_received_ts(events[-1].event_id)
assert ts is not None
synapse.metrics.event_processing_lag.labels(
"federation_sender"

View file

@ -398,6 +398,7 @@ class AccountValidityHandler:
"""
now = self.clock.time_msec()
if expiration_ts is None:
assert self._account_validity_period is not None
expiration_ts = now + self._account_validity_period
await self.store.set_account_validity_for_user(

View file

@ -131,6 +131,8 @@ class ApplicationServicesHandler:
now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)
assert ts is not None
synapse.metrics.event_processing_lag_by_event.labels(
"appservice_sender"
).observe((now - ts) / 1000)
@ -166,6 +168,7 @@ class ApplicationServicesHandler:
if events:
now = self.clock.time_msec()
ts = await self.store.get_received_ts(events[-1].event_id)
assert ts is not None
synapse.metrics.event_processing_lag.labels(
"appservice_sender"

View file

@ -28,6 +28,7 @@ from bisect import bisect
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Dict,
@ -615,7 +616,7 @@ class PresenceHandler(BasePresenceHandler):
super().__init__(hs)
self.hs = hs
self.server_name = hs.hostname
self.wheel_timer = WheelTimer()
self.wheel_timer: WheelTimer[str] = WheelTimer()
self.notifier = hs.get_notifier()
self._presence_enabled = hs.config.use_presence
@ -924,7 +925,7 @@ class PresenceHandler(BasePresenceHandler):
prev_state = await self.current_state_for_user(user_id)
new_fields = {"last_active_ts": self.clock.time_msec()}
new_fields: Dict[str, Any] = {"last_active_ts": self.clock.time_msec()}
if prev_state.state == PresenceState.UNAVAILABLE:
new_fields["state"] = PresenceState.ONLINE

View file

@ -73,7 +73,7 @@ class FollowerTypingHandler:
self._room_typing: Dict[str, Set[str]] = {}
self._member_last_federation_poke: Dict[RoomMember, int] = {}
self.wheel_timer = WheelTimer(bucket_size=5000)
self.wheel_timer: WheelTimer[RoomMember] = WheelTimer(bucket_size=5000)
self._latest_room_serial = 0
self.clock.looping_call(self._handle_timeouts, 5000)

View file

@ -330,11 +330,11 @@ class UsernameAvailabilityRestServlet(RestServlet):
# Artificially delay requests if rate > sleep_limit/window_size
sleep_limit=1,
# Amount of artificial delay to apply
sleep_msec=1000,
sleep_delay=1000,
# Error with 429 if more than reject_limit requests are queued
reject_limit=1,
# Allow 1 request at a time
concurrent_requests=1,
concurrent=1,
),
)
@ -763,7 +763,10 @@ class RegisterRestServlet(RestServlet):
Returns:
dictionary for response from /register
"""
result = {"user_id": user_id, "home_server": self.hs.hostname}
result: JsonDict = {
"user_id": user_id,
"home_server": self.hs.hostname,
}
if not params.get("inhibit_login", False):
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
@ -814,7 +817,7 @@ class RegisterRestServlet(RestServlet):
user_id, device_id, initial_display_name, is_guest=True
)
result = {
result: JsonDict = {
"user_id": user_id,
"device_id": device_id,
"access_token": access_token,

View file

@ -52,7 +52,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
yield hs.config.sso.sso_template_dir
yield hs.config.sso.default_template_dir
self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config)
async def _async_render_GET(self, request: Request) -> None:
try:

View file

@ -80,7 +80,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
yield hs.config.sso.sso_template_dir
yield hs.config.sso.default_template_dir
self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config)
async def _async_render_GET(self, request: Request) -> None:
try:

View file

@ -1091,6 +1091,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
delta equal to 10% of the validity period.
"""
now_ms = self._clock.time_msec()
assert self._account_validity_period is not None
expiration_ts = now_ms + self._account_validity_period
if use_delta:

View file

@ -38,6 +38,7 @@ from twisted.internet.interfaces import (
IReactorCore,
IReactorPluggableNameResolver,
IReactorTCP,
IReactorThreads,
IReactorTime,
)
@ -63,7 +64,12 @@ JsonDict = Dict[str, Any]
# Note that this seems to require inheriting *directly* from Interface in order
# for mypy-zope to realize it is an interface.
class ISynapseReactor(
IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface
IReactorTCP,
IReactorPluggableNameResolver,
IReactorTime,
IReactorCore,
IReactorThreads,
Interface,
):
"""The interfaces necessary for Synapse to function."""

View file

@ -15,27 +15,35 @@
import json
import logging
import re
from typing import Pattern
import typing
from typing import Any, Callable, Dict, Generator, Pattern
import attr
from frozendict import frozendict
from twisted.internet import defer, task
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IDelayedCall, IReactorTime
from twisted.internet.task import LoopingCall
from twisted.python.failure import Failure
from synapse.logging import context
if typing.TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
_WILDCARD_RUN = re.compile(r"([\?\*]+)")
def _reject_invalid_json(val):
def _reject_invalid_json(val: Any) -> None:
"""Do not allow Infinity, -Infinity, or NaN values in JSON."""
raise ValueError("Invalid JSON value: '%s'" % val)
def _handle_frozendict(obj):
def _handle_frozendict(obj: Any) -> Dict[Any, Any]:
"""Helper for json_encoder. Makes frozendicts serializable by returning
the underlying dict
"""
@ -60,10 +68,10 @@ json_encoder = json.JSONEncoder(
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
def unwrapFirstError(failure):
def unwrapFirstError(failure: Failure) -> Failure:
# defer.gatherResults and DeferredLists wrap failures.
failure.trap(defer.FirstError)
return failure.value.subFailure
return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations
@attr.s(slots=True)
@ -75,25 +83,25 @@ class Clock:
reactor: The Twisted reactor to use.
"""
_reactor = attr.ib()
_reactor: IReactorTime = attr.ib()
@defer.inlineCallbacks
def sleep(self, seconds):
d = defer.Deferred()
@defer.inlineCallbacks # type: ignore[arg-type] # Issue in Twisted's type annotations
def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]":
d: defer.Deferred[float] = defer.Deferred()
with context.PreserveLoggingContext():
self._reactor.callLater(seconds, d.callback, seconds)
res = yield d
return res
def time(self):
def time(self) -> float:
"""Returns the current system time in seconds since epoch."""
return self._reactor.seconds()
def time_msec(self):
def time_msec(self) -> int:
"""Returns the current system time in milliseconds since epoch."""
return int(self.time() * 1000)
def looping_call(self, f, msec, *args, **kwargs):
def looping_call(self, f: Callable, msec: float, *args, **kwargs) -> LoopingCall:
"""Call a function repeatedly.
Waits `msec` initially before calling `f` for the first time.
@ -102,8 +110,8 @@ class Clock:
other than trivial, you probably want to wrap it in run_as_background_process.
Args:
f(function): The function to call repeatedly.
msec(float): How long to wait between calls in milliseconds.
f: The function to call repeatedly.
msec: How long to wait between calls in milliseconds.
*args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
@ -113,7 +121,7 @@ class Clock:
d.addErrback(log_failure, "Looping call died", consumeErrors=False)
return call
def call_later(self, delay, callback, *args, **kwargs):
def call_later(self, delay, callback, *args, **kwargs) -> IDelayedCall:
"""Call something later
Note that the function will be called with no logcontext, so if it is anything
@ -133,7 +141,7 @@ class Clock:
with context.PreserveLoggingContext():
return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer, ignore_errs=False):
def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None:
try:
timer.cancel()
except Exception:

View file

@ -37,6 +37,7 @@ import attr
from typing_extensions import ContextManager
from twisted.internet import defer
from twisted.internet.base import ReactorBase
from twisted.internet.defer import CancelledError
from twisted.internet.interfaces import IReactorTime
from twisted.python import failure
@ -268,6 +269,7 @@ class Linearizer:
if not clock:
from twisted.internet import reactor
assert isinstance(reactor, ReactorBase)
clock = Clock(reactor)
self._clock = clock
self.max_count = max_count
@ -411,7 +413,7 @@ class ReadWriteLock:
# writers and readers have been resolved. The new writer replaces the latest
# writer.
def __init__(self):
def __init__(self) -> None:
# Latest readers queued
self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {}
@ -503,7 +505,7 @@ def timeout_deferred(
timed_out = [False]
def time_it_out():
def time_it_out() -> None:
timed_out[0] = True
try:
@ -550,19 +552,21 @@ def timeout_deferred(
return new_d
# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
@attr.s(slots=True, frozen=True)
class DoneAwaitable:
class DoneAwaitable: # should be: Generic[R]
"""Simple awaitable that returns the provided value."""
value = attr.ib()
value = attr.ib(type=Any) # should be: R
def __await__(self):
return self
def __iter__(self):
def __iter__(self) -> "DoneAwaitable":
return self
def __next__(self):
def __next__(self) -> None:
raise StopIteration(self.value)

View file

@ -122,7 +122,7 @@ class BatchingQueue(Generic[V, R]):
# First we create a defer and add it and the value to the list of
# pending items.
d = defer.Deferred()
d: defer.Deferred[R] = defer.Deferred()
self._next_values.setdefault(key, []).append((value, d))
# If we're not currently processing the key fire off a background

View file

@ -64,32 +64,32 @@ class CacheMetric:
evicted_size = attr.ib(default=0)
memory_usage = attr.ib(default=None)
def inc_hits(self):
def inc_hits(self) -> None:
self.hits += 1
def inc_misses(self):
def inc_misses(self) -> None:
self.misses += 1
def inc_evictions(self, size=1):
def inc_evictions(self, size: int = 1) -> None:
self.evicted_size += size
def inc_memory_usage(self, memory: int):
def inc_memory_usage(self, memory: int) -> None:
if self.memory_usage is None:
self.memory_usage = 0
self.memory_usage += memory
def dec_memory_usage(self, memory: int):
def dec_memory_usage(self, memory: int) -> None:
self.memory_usage -= memory
def clear_memory_usage(self):
def clear_memory_usage(self) -> None:
if self.memory_usage is not None:
self.memory_usage = 0
def describe(self):
return []
def collect(self):
def collect(self) -> None:
try:
if self._cache_type == "response_cache":
response_cache_size.labels(self._cache_name).set(len(self._cache))

View file

@ -93,7 +93,7 @@ class DeferredCache(Generic[KT, VT]):
TreeCache, "MutableMapping[KT, CacheEntry]"
] = cache_type()
def metrics_cb():
def metrics_cb() -> None:
cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
# cache is used for completed results and maps to the result itself, rather than
@ -113,7 +113,7 @@ class DeferredCache(Generic[KT, VT]):
def max_entries(self):
return self.cache.max_size
def check_thread(self):
def check_thread(self) -> None:
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
@ -235,7 +235,7 @@ class DeferredCache(Generic[KT, VT]):
self._pending_deferred_cache[key] = entry
def compare_and_pop():
def compare_and_pop() -> bool:
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.
@ -256,7 +256,7 @@ class DeferredCache(Generic[KT, VT]):
return False
def cb(result):
def cb(result) -> None:
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
@ -268,7 +268,7 @@ class DeferredCache(Generic[KT, VT]):
# not have been. Either way, let's double-check now.
entry.invalidate()
def eb(_fail):
def eb(_fail) -> None:
compare_and_pop()
entry.invalidate()
@ -314,7 +314,7 @@ class DeferredCache(Generic[KT, VT]):
for entry in iterate_tree_cache_entry(entry):
entry.invalidate()
def invalidate_all(self):
def invalidate_all(self) -> None:
self.check_thread()
self.cache.clear()
for entry in self._pending_deferred_cache.values():
@ -332,7 +332,7 @@ class CacheEntry:
self.callbacks = set(callbacks)
self.invalidated = False
def invalidate(self):
def invalidate(self) -> None:
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:

View file

@ -27,10 +27,14 @@ logger = logging.getLogger(__name__)
KT = TypeVar("KT")
# The type of the dictionary keys.
DKT = TypeVar("DKT")
# The type of the dictionary values.
DV = TypeVar("DV")
# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
@attr.s(slots=True)
class DictionaryEntry:
class DictionaryEntry: # should be: Generic[DKT, DV].
"""Returned when getting an entry from the cache
Attributes:
@ -43,10 +47,10 @@ class DictionaryEntry:
"""
full = attr.ib(type=bool)
known_absent = attr.ib()
value = attr.ib()
known_absent = attr.ib(type=Set[Any]) # should be: Set[DKT]
value = attr.ib(type=Dict[Any, Any]) # should be: Dict[DKT, DV]
def __len__(self):
def __len__(self) -> int:
return len(self.value)
@ -56,7 +60,7 @@ class _Sentinel(enum.Enum):
sentinel = object()
class DictionaryCache(Generic[KT, DKT]):
class DictionaryCache(Generic[KT, DKT, DV]):
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
"""
@ -87,7 +91,7 @@ class DictionaryCache(Generic[KT, DKT]):
Args:
key
dict_key: If given a set of keys then return only those keys
dict_keys: If given a set of keys then return only those keys
that exist in the cache.
Returns:
@ -125,7 +129,7 @@ class DictionaryCache(Generic[KT, DKT]):
self,
sequence: int,
key: KT,
value: Dict[DKT, Any],
value: Dict[DKT, DV],
fetched_keys: Optional[Set[DKT]] = None,
) -> None:
"""Updates the entry in the cache
@ -151,15 +155,15 @@ class DictionaryCache(Generic[KT, DKT]):
self._update_or_insert(key, value, fetched_keys)
def _update_or_insert(
self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]
self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]
) -> None:
# We pop and reinsert as we need to tell the cache the size may have
# changed
entry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
entry.value.update(value)
entry.known_absent.update(known_absent)
self.cache[key] = entry
def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None:
def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None:
self.cache[key] = DictionaryEntry(True, known_absent, value)

View file

@ -35,6 +35,7 @@ from typing import (
from typing_extensions import Literal
from twisted.internet import reactor
from twisted.internet.interfaces import IReactorTime
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import wrap_as_background_process
@ -341,7 +342,7 @@ class LruCache(Generic[KT, VT]):
# Default `clock` to something sensible. Note that we rename it to
# `real_clock` so that mypy doesn't think its still `Optional`.
if clock is None:
real_clock = Clock(reactor)
real_clock = Clock(cast(IReactorTime, reactor))
else:
real_clock = clock
@ -384,7 +385,7 @@ class LruCache(Generic[KT, VT]):
lock = threading.Lock()
def evict():
def evict() -> None:
while cache_len() > self.max_size:
# Get the last node in the list (i.e. the oldest node).
todelete = list_root.prev_node

View file

@ -195,7 +195,7 @@ class StreamChangeCache:
for entity in r:
del self._entity_to_key[entity]
def _evict(self):
def _evict(self) -> None:
while len(self._cache) > self._max_size:
k, r = self._cache.popitem(0)
self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)

View file

@ -35,17 +35,17 @@ class TreeCache:
root = {key_1: {key_2: _value}}
"""
def __init__(self):
self.size = 0
def __init__(self) -> None:
self.size: int = 0
self.root = TreeCacheNode()
def __setitem__(self, key, value):
return self.set(key, value)
def __setitem__(self, key, value) -> None:
self.set(key, value)
def __contains__(self, key):
def __contains__(self, key) -> bool:
return self.get(key, SENTINEL) is not SENTINEL
def set(self, key, value):
def set(self, key, value) -> None:
if isinstance(value, TreeCacheNode):
# this would mean we couldn't tell where our tree ended and the value
# started.
@ -73,7 +73,7 @@ class TreeCache:
return default
return node.get(key[-1], default)
def clear(self):
def clear(self) -> None:
self.size = 0
self.root = TreeCacheNode()
@ -128,7 +128,7 @@ class TreeCache:
def values(self):
return iterate_tree_cache_entry(self.root)
def __len__(self):
def __len__(self) -> int:
return self.size

View file

@ -126,7 +126,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
signal.signal(signal.SIGTERM, sigterm)
# Cleanup pid file at exit.
def exit():
def exit() -> None:
logger.warning("Stopping daemon.")
os.remove(pid_file)
sys.exit(0)

View file

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, Dict, List
from twisted.internet import defer
@ -37,11 +38,11 @@ class Distributor:
model will do for today.
"""
def __init__(self):
self.signals = {}
self.pre_registration = {}
def __init__(self) -> None:
self.signals: Dict[str, Signal] = {}
self.pre_registration: Dict[str, List[Callable]] = {}
def declare(self, name):
def declare(self, name: str) -> None:
if name in self.signals:
raise KeyError("%r already has a signal named %s" % (self, name))
@ -52,7 +53,7 @@ class Distributor:
for observer in self.pre_registration[name]:
signal.observe(observer)
def observe(self, name, observer):
def observe(self, name: str, observer: Callable) -> None:
if name in self.signals:
self.signals[name].observe(observer)
else:
@ -62,7 +63,7 @@ class Distributor:
self.pre_registration[name] = []
self.pre_registration[name].append(observer)
def fire(self, name, *args, **kwargs):
def fire(self, name: str, *args, **kwargs) -> None:
"""Dispatches the given signal to the registered observers.
Runs the observers as a background process. Does not return a deferred.
@ -83,18 +84,18 @@ class Signal:
method into all of the observers.
"""
def __init__(self, name):
self.name = name
self.observers = []
def __init__(self, name: str):
self.name: str = name
self.observers: List[Callable] = []
def observe(self, observer):
def observe(self, observer: Callable) -> None:
"""Adds a new callable to the observer list which will be invoked by
the 'fire' method.
Each observer callable may return a Deferred."""
self.observers.append(observer)
def fire(self, *args, **kwargs):
def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]":
"""Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is
not an error to fire a signal with no observers.

View file

@ -13,10 +13,14 @@
# limitations under the License.
import queue
from typing import BinaryIO, Optional, Union, cast
from twisted.internet import threads
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IPullProducer, IPushProducer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import ISynapseReactor
class BackgroundFileConsumer:
@ -24,9 +28,9 @@ class BackgroundFileConsumer:
and pull producers
Args:
file_obj (file): The file like object to write to. Closed when
file_obj: The file like object to write to. Closed when
finished.
reactor (twisted.internet.reactor): the Twisted reactor to use
reactor: the Twisted reactor to use
"""
# For PushProducers pause if we have this many unwritten slices
@ -34,13 +38,13 @@ class BackgroundFileConsumer:
# And resume once the size of the queue is less than this
_RESUME_ON_QUEUE_SIZE = 2
def __init__(self, file_obj, reactor):
self._file_obj = file_obj
def __init__(self, file_obj: BinaryIO, reactor: ISynapseReactor) -> None:
self._file_obj: BinaryIO = file_obj
self._reactor = reactor
self._reactor: ISynapseReactor = reactor
# Producer we're registered with
self._producer = None
self._producer: Optional[Union[IPushProducer, IPullProducer]] = None
# True if PushProducer, false if PullProducer
self.streaming = False
@ -51,20 +55,22 @@ class BackgroundFileConsumer:
# Queue of slices of bytes to be written. When producer calls
# unregister a final None is sent.
self._bytes_queue = queue.Queue()
self._bytes_queue: queue.Queue[Optional[bytes]] = queue.Queue()
# Deferred that is resolved when finished writing
self._finished_deferred = None
self._finished_deferred: Optional[Deferred[None]] = None
# If the _writer thread throws an exception it gets stored here.
self._write_exception = None
self._write_exception: Optional[Exception] = None
def registerProducer(self, producer, streaming):
def registerProducer(
self, producer: Union[IPushProducer, IPullProducer], streaming: bool
) -> None:
"""Part of IConsumer interface
Args:
producer (IProducer)
streaming (bool): True if push based producer, False if pull
producer
streaming: True if push based producer, False if pull
based.
"""
if self._producer:
@ -81,29 +87,33 @@ class BackgroundFileConsumer:
if not streaming:
self._producer.resumeProducing()
def unregisterProducer(self):
def unregisterProducer(self) -> None:
"""Part of IProducer interface"""
self._producer = None
assert self._finished_deferred is not None
if not self._finished_deferred.called:
self._bytes_queue.put_nowait(None)
def write(self, bytes):
def write(self, write_bytes: bytes) -> None:
"""Part of IProducer interface"""
if self._write_exception:
raise self._write_exception
assert self._finished_deferred is not None
if self._finished_deferred.called:
raise Exception("consumer has closed")
self._bytes_queue.put_nowait(bytes)
self._bytes_queue.put_nowait(write_bytes)
# If this is a PushProducer and the queue is getting behind
# then we pause the producer.
if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
self._paused_producer = True
self._producer.pauseProducing()
assert self._producer is not None
# cast safe because `streaming` means this is an IPushProducer
cast(IPushProducer, self._producer).pauseProducing()
def _writer(self):
def _writer(self) -> None:
"""This is run in a background thread to write to the file."""
try:
while self._producer or not self._bytes_queue.empty():
@ -130,11 +140,11 @@ class BackgroundFileConsumer:
finally:
self._file_obj.close()
def wait(self):
def wait(self) -> "Deferred[None]":
"""Returns a deferred that resolves when finished writing to file"""
return make_deferred_yieldable(self._finished_deferred)
def _resume_paused_producer(self):
def _resume_paused_producer(self) -> None:
"""Gets called if we should resume producing after being paused"""
if self._paused_producer and self._producer:
self._paused_producer = False

View file

@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from frozendict import frozendict
def freeze(o):
def freeze(o: Any) -> Any:
if isinstance(o, dict):
return frozendict({k: freeze(v) for k, v in o.items()})
@ -33,7 +34,7 @@ def freeze(o):
return o
def unfreeze(o):
def unfreeze(o: Any) -> Any:
if isinstance(o, (dict, frozendict)):
return {k: unfreeze(v) for k, v in o.items()}

View file

@ -13,42 +13,43 @@
# limitations under the License.
import logging
from typing import Dict
from twisted.web.resource import NoResource
from twisted.web.resource import NoResource, Resource
logger = logging.getLogger(__name__)
def create_resource_tree(desired_tree, root_resource):
def create_resource_tree(
desired_tree: Dict[str, Resource], root_resource: Resource
) -> Resource:
"""Create the resource tree for this homeserver.
This in unduly complicated because Twisted does not support putting
child resources more than 1 level deep at a time.
Args:
web_client (bool): True to enable the web client.
root_resource (twisted.web.resource.Resource): The root
resource to add the tree to.
desired_tree: Dict from desired paths to desired resources.
root_resource: The root resource to add the tree to.
Returns:
twisted.web.resource.Resource: the ``root_resource`` with a tree of
child resources added to it.
The ``root_resource`` with a tree of child resources added to it.
"""
# ideally we'd just use getChild and putChild but getChild doesn't work
# unless you give it a Request object IN ADDITION to the name :/ So
# instead, we'll store a copy of this mapping so we can actually add
# extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {}
for full_path, res in desired_tree.items():
resource_mappings: Dict[str, Resource] = {}
for full_path_str, res in desired_tree.items():
# twisted requires all resources to be bytes
full_path = full_path.encode("utf-8")
full_path = full_path_str.encode("utf-8")
logger.info("Attaching %s to path %s", res, full_path)
last_resource = root_resource
for path_seg in full_path.split(b"/")[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
child_resource = NoResource()
child_resource: Resource = NoResource()
last_resource.putChild(path_seg, child_resource)
res_id = _resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource
@ -83,7 +84,7 @@ def create_resource_tree(desired_tree, root_resource):
return root_resource
def _resource_id(resource, path_seg):
def _resource_id(resource: Resource, path_seg: bytes) -> str:
"""Construct an arbitrary resource ID so you can retrieve the mapping
later.
@ -96,4 +97,4 @@ def _resource_id(resource, path_seg):
Returns:
str: A unique string which can be a key to the child Resource.
"""
return "%s-%s" % (resource, path_seg)
return "%s-%r" % (resource, path_seg)

View file

@ -74,7 +74,7 @@ class ListNode(Generic[P]):
new_node._refs_insert_after(node)
return new_node
def remove_from_list(self):
def remove_from_list(self) -> None:
"""Remove this node from the list."""
with self._LOCK:
self._refs_remove_node_from_list()
@ -84,7 +84,7 @@ class ListNode(Generic[P]):
# immediately rather than at the next GC.
self.cache_entry = None
def move_after(self, node: "ListNode"):
def move_after(self, node: "ListNode") -> None:
"""Move this node from its current location in the list to after the
given node.
"""
@ -103,7 +103,7 @@ class ListNode(Generic[P]):
# Insert self back into the list, after target node
self._refs_insert_after(node)
def _refs_remove_node_from_list(self):
def _refs_remove_node_from_list(self) -> None:
"""Internal method to *just* remove the node from the list, without
e.g. clearing out the cache entry.
"""
@ -122,7 +122,7 @@ class ListNode(Generic[P]):
self.prev_node = None
self.next_node = None
def _refs_insert_after(self, node: "ListNode"):
def _refs_insert_after(self, node: "ListNode") -> None:
"""Internal method to insert the node after the given node."""
# This method should only be called when we're not already in the list.

View file

@ -77,7 +77,7 @@ def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> N
should be considered expired. Normally the current time.
"""
def verify_expiry_caveat(caveat: str):
def verify_expiry_caveat(caveat: str) -> bool:
time_msec = get_time_ms()
prefix = "time < "
if not caveat.startswith(prefix):

View file

@ -15,6 +15,7 @@
import inspect
import sys
import traceback
from typing import Any, Dict, Optional
from twisted.conch import manhole_ssh
from twisted.conch.insults import insults
@ -22,6 +23,9 @@ from twisted.conch.manhole import ColoredManhole, ManholeInterpreter
from twisted.conch.ssh.keys import Key
from twisted.cred import checkers, portal
from twisted.internet import defer
from twisted.internet.protocol import Factory
from synapse.config.server import ManholeConfig
PUBLIC_KEY = (
"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5"
@ -61,22 +65,22 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs=
-----END RSA PRIVATE KEY-----"""
def manhole(settings, globals):
def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory:
"""Starts a ssh listener with password authentication using
the given username and password. Clients connecting to the ssh
listener will find themselves in a colored python shell with
the supplied globals.
Args:
username(str): The username ssh clients should auth with.
password(str): The password ssh clients should auth with.
globals(dict): The variables to expose in the shell.
username: The username ssh clients should auth with.
password: The password ssh clients should auth with.
globals: The variables to expose in the shell.
Returns:
twisted.internet.protocol.Factory: A factory to pass to ``listenTCP``
A factory to pass to ``listenTCP``
"""
username = settings.username
password = settings.password
password = settings.password.encode("ascii")
priv_key = settings.priv_key
if priv_key is None:
priv_key = Key.fromString(PRIVATE_KEY)
@ -84,19 +88,22 @@ def manhole(settings, globals):
if pub_key is None:
pub_key = Key.fromString(PUBLIC_KEY)
if not isinstance(password, bytes):
password = password.encode("ascii")
checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password})
rlm = manhole_ssh.TerminalRealm()
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
# mypy ignored here because:
# - can't deduce types of lambdas
# - variable is Type[ServerProtocol], expr is Callable[[], ServerProtocol]
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( # type: ignore[misc,assignment]
SynapseManhole, dict(globals, __name__="__console__")
)
factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
factory.privateKeys[b"ssh-rsa"] = priv_key
factory.publicKeys[b"ssh-rsa"] = pub_key
# conch has the wrong type on these dicts (says bytes to bytes,
# should be bytes to Keys judging by how it's used).
factory.privateKeys[b"ssh-rsa"] = priv_key # type: ignore[assignment]
factory.publicKeys[b"ssh-rsa"] = pub_key # type: ignore[assignment]
return factory
@ -104,7 +111,7 @@ def manhole(settings, globals):
class SynapseManhole(ColoredManhole):
"""Overrides connectionMade to create our own ManholeInterpreter"""
def connectionMade(self):
def connectionMade(self) -> None:
super().connectionMade()
# replace the manhole interpreter with our own impl
@ -114,13 +121,14 @@ class SynapseManhole(ColoredManhole):
class SynapseManholeInterpreter(ManholeInterpreter):
def showsyntaxerror(self, filename=None):
def showsyntaxerror(self, filename: Optional[str] = None) -> None:
"""Display the syntax error that just occurred.
Overrides the base implementation, ignoring sys.excepthook. We always want
any syntax errors to be sent to the terminal, rather than sentry.
"""
type, value, tb = sys.exc_info()
assert value is not None
sys.last_type = type
sys.last_value = value
sys.last_traceback = tb
@ -138,7 +146,7 @@ class SynapseManholeInterpreter(ManholeInterpreter):
lines = traceback.format_exception_only(type, value)
self.write("".join(lines))
def showtraceback(self):
def showtraceback(self) -> None:
"""Display the exception that just occurred.
Overrides the base implementation, ignoring sys.excepthook. We always want
@ -146,14 +154,22 @@ class SynapseManholeInterpreter(ManholeInterpreter):
"""
sys.last_type, sys.last_value, last_tb = ei = sys.exc_info()
sys.last_traceback = last_tb
assert last_tb is not None
try:
# We remove the first stack item because it is our own code.
lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next)
self.write("".join(lines))
finally:
last_tb = ei = None
# On the line below, last_tb and ei appear to be dead.
# It's unclear whether there is a reason behind this line.
# It conceivably could be because an exception raised in this block
# will keep the local frame (containing these local variables) around.
# This was adapted taken from CPython's Lib/code.py; see here:
# https://github.com/python/cpython/blob/4dc4300c686f543d504ab6fa9fe600eaf11bb695/Lib/code.py#L131-L150
last_tb = ei = None # type: ignore
def displayhook(self, obj):
def displayhook(self, obj: Any) -> None:
"""
We override the displayhook so that we automatically convert coroutines
into Deferreds. (Our superclass' displayhook will take care of the rest,

View file

@ -24,7 +24,7 @@ from twisted.python.failure import Failure
_already_patched = False
def do_patch():
def do_patch() -> None:
"""
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
"""
@ -107,7 +107,7 @@ def do_patch():
_already_patched = True
def _check_yield_points(f: Callable, changes: List[str]):
def _check_yield_points(f: Callable, changes: List[str]) -> Callable:
"""Wraps a generator that is about to be passed to defer.inlineCallbacks
checking that after every yield the log contexts are correct.

View file

@ -15,33 +15,36 @@
import collections
import contextlib
import logging
import typing
from typing import Any, DefaultDict, Iterator, List, Set
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
)
from synapse.util import Clock
if typing.TYPE_CHECKING:
from contextlib import _GeneratorContextManager
logger = logging.getLogger(__name__)
class FederationRateLimiter:
def __init__(self, clock, config):
"""
Args:
clock (Clock)
config (FederationRateLimitConfig)
"""
def new_limiter():
def __init__(self, clock: Clock, config: FederationRateLimitConfig):
def new_limiter() -> "_PerHostRatelimiter":
return _PerHostRatelimiter(clock=clock, config=config)
self.ratelimiters = collections.defaultdict(new_limiter)
self.ratelimiters: DefaultDict[
str, "_PerHostRatelimiter"
] = collections.defaultdict(new_limiter)
def ratelimit(self, host):
def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
"""Used to ratelimit an incoming request from a given host
Example usage:
@ -60,11 +63,11 @@ class FederationRateLimiter:
class _PerHostRatelimiter:
def __init__(self, clock, config):
def __init__(self, clock: Clock, config: FederationRateLimitConfig):
"""
Args:
clock (Clock)
config (FederationRateLimitConfig)
clock
config
"""
self.clock = clock
@ -75,21 +78,23 @@ class _PerHostRatelimiter:
self.concurrent_requests = config.concurrent
# request_id objects for requests which have been slept
self.sleeping_requests = set()
self.sleeping_requests: Set[object] = set()
# map from request_id object to Deferred for requests which are ready
# for processing but have been queued
self.ready_request_queue = collections.OrderedDict()
self.ready_request_queue: collections.OrderedDict[
object, defer.Deferred[None]
] = collections.OrderedDict()
# request id objects for requests which are in progress
self.current_processing = set()
self.current_processing: Set[object] = set()
# times at which we have recently (within the last window_size ms)
# received requests.
self.request_times = []
self.request_times: List[int] = []
@contextlib.contextmanager
def ratelimit(self):
def ratelimit(self) -> "Iterator[defer.Deferred[None]]":
# `contextlib.contextmanager` takes a generator and turns it into a
# context manager. The generator should only yield once with a value
# to be returned by manager.
@ -102,7 +107,7 @@ class _PerHostRatelimiter:
finally:
self._on_exit(request_id)
def _on_enter(self, request_id):
def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
time_now = self.clock.time_msec()
# remove any entries from request_times which aren't within the window
@ -120,9 +125,9 @@ class _PerHostRatelimiter:
self.request_times.append(time_now)
def queue_request():
def queue_request() -> "defer.Deferred[None]":
if len(self.current_processing) >= self.concurrent_requests:
queue_defer = defer.Deferred()
queue_defer: defer.Deferred[None] = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
logger.info(
"Ratelimiter: queueing request (queue now %i items)",
@ -145,7 +150,7 @@ class _PerHostRatelimiter:
self.sleeping_requests.add(request_id)
def on_wait_finished(_):
def on_wait_finished(_: Any) -> "defer.Deferred[None]":
logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
self.sleeping_requests.discard(request_id)
queue_defer = queue_request()
@ -155,19 +160,19 @@ class _PerHostRatelimiter:
else:
ret_defer = queue_request()
def on_start(r):
def on_start(r: object) -> object:
logger.debug("Ratelimit [%s]: Processing req", id(request_id))
self.current_processing.add(request_id)
return r
def on_err(r):
def on_err(r: object) -> object:
# XXX: why is this necessary? this is called before we start
# processing the request so why would the request be in
# current_processing?
self.current_processing.discard(request_id)
return r
def on_both(r):
def on_both(r: object) -> object:
# Ensure that we've properly cleaned up.
self.sleeping_requests.discard(request_id)
self.ready_request_queue.pop(request_id, None)
@ -177,7 +182,7 @@ class _PerHostRatelimiter:
ret_defer.addBoth(on_both)
return make_deferred_yieldable(ret_defer)
def _on_exit(self, request_id):
def _on_exit(self, request_id: object) -> None:
logger.debug("Ratelimit [%s]: Processed req", id(request_id))
self.current_processing.discard(request_id)
try:

View file

@ -13,9 +13,13 @@
# limitations under the License.
import logging
import random
from types import TracebackType
from typing import Any, Optional, Type
import synapse.logging.context
from synapse.api.errors import CodeMessageException
from synapse.storage import DataStore
from synapse.util import Clock
logger = logging.getLogger(__name__)
@ -30,17 +34,17 @@ MAX_RETRY_INTERVAL = 2 ** 62
class NotRetryingDestination(Exception):
def __init__(self, retry_last_ts, retry_interval, destination):
def __init__(self, retry_last_ts: int, retry_interval: int, destination: str):
"""Raised by the limiter (and federation client) to indicate that we are
are deliberately not attempting to contact a given server.
Args:
retry_last_ts (int): the unix ts in milliseconds of our last attempt
retry_last_ts: the unix ts in milliseconds of our last attempt
to contact the server. 0 indicates that the last attempt was
successful or that we've never actually attempted to connect.
retry_interval (int): the time in milliseconds to wait until the next
retry_interval: the time in milliseconds to wait until the next
attempt.
destination (str): the domain in question
destination: the domain in question
"""
msg = "Not retrying server %s." % (destination,)
@ -51,7 +55,13 @@ class NotRetryingDestination(Exception):
self.destination = destination
async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
async def get_retry_limiter(
destination: str,
clock: Clock,
store: DataStore,
ignore_backoff: bool = False,
**kwargs: Any,
) -> "RetryDestinationLimiter":
"""For a given destination check if we have previously failed to
send a request there and are waiting before retrying the destination.
If we are not ready to retry the destination, this will raise a
@ -60,10 +70,10 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
CodeMessageException with code < 500)
Args:
destination (str): name of homeserver
clock (synapse.util.clock): timing source
store (synapse.storage.transactions.TransactionStore): datastore
ignore_backoff (bool): true to ignore the historical backoff data and
destination: name of homeserver
clock: timing source
store: datastore
ignore_backoff: true to ignore the historical backoff data and
try the request anyway. We will still reset the retry_interval on success.
Example usage:
@ -114,13 +124,13 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
class RetryDestinationLimiter:
def __init__(
self,
destination,
clock,
store,
failure_ts,
retry_interval,
backoff_on_404=False,
backoff_on_failure=True,
destination: str,
clock: Clock,
store: DataStore,
failure_ts: Optional[int],
retry_interval: int,
backoff_on_404: bool = False,
backoff_on_failure: bool = True,
):
"""Marks the destination as "down" if an exception is thrown in the
context, except for CodeMessageException with code < 500.
@ -128,17 +138,17 @@ class RetryDestinationLimiter:
If no exception is raised, marks the destination as "up".
Args:
destination (str)
clock (Clock)
store (DataStore)
failure_ts (int|None): when this destination started failing (in ms since
destination
clock
store
failure_ts: when this destination started failing (in ms since
the epoch), or zero if the last request was successful
retry_interval (int): The next retry interval taken from the
retry_interval: The next retry interval taken from the
database in milliseconds, or zero if the last request was
successful.
backoff_on_404 (bool): Back off if we get a 404
backoff_on_404: Back off if we get a 404
backoff_on_failure (bool): set to False if we should not increase the
backoff_on_failure: set to False if we should not increase the
retry interval on a failure.
"""
self.clock = clock
@ -150,10 +160,15 @@ class RetryDestinationLimiter:
self.backoff_on_404 = backoff_on_404
self.backoff_on_failure = backoff_on_failure
def __enter__(self):
def __enter__(self) -> None:
pass
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
valid_err_code = False
if exc_type is None:
valid_err_code = True
@ -161,7 +176,7 @@ class RetryDestinationLimiter:
# avoid treating exceptions which don't derive from Exception as
# failures; this is mostly so as not to catch defer._DefGen.
valid_err_code = True
elif issubclass(exc_type, CodeMessageException):
elif isinstance(exc_val, CodeMessageException):
# Some error codes are perfectly fine for some APIs, whereas other
# APIs may expect to never received e.g. a 404. It's important to
# handle 404 as some remote servers will return a 404 when the HS
@ -216,7 +231,7 @@ class RetryDestinationLimiter:
if self.failure_ts is None:
self.failure_ts = retry_last_ts
async def store_retry_timings():
async def store_retry_timings() -> None:
try:
await self.store.set_destination_retry_timings(
self.destination,

View file

@ -18,7 +18,7 @@ import resource
logger = logging.getLogger("synapse.app.homeserver")
def change_resource_limit(soft_file_no):
def change_resource_limit(soft_file_no: int) -> None:
try:
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)

View file

@ -16,7 +16,7 @@
import time
import urllib.parse
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union
import jinja2
@ -25,9 +25,9 @@ if TYPE_CHECKING:
def build_jinja_env(
template_search_directories: Iterable[str],
template_search_directories: Sequence[str],
config: "HomeServerConfig",
autoescape: Union[bool, Callable[[str], bool], None] = None,
autoescape: Union[bool, Callable[[Optional[str]], bool], None] = None,
) -> jinja2.Environment:
"""Set up a Jinja2 environment to load templates from the given search path
@ -110,5 +110,5 @@ def _create_mxc_to_http_filter(
return mxc_to_http_filter
def _format_ts_filter(value: int, format: str):
def _format_ts_filter(value: int, format: str) -> str:
return time.strftime(format, time.localtime(value / 1000))

View file

@ -14,6 +14,10 @@
import logging
import re
import typing
if typing.TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -28,13 +32,13 @@ logger = logging.getLogger(__name__)
MAX_EMAIL_ADDRESS_LENGTH = 500
def check_3pid_allowed(hs, medium, address):
def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool:
"""Checks whether a given format of 3PID is allowed to be used on this HS
Args:
hs (synapse.server.HomeServer): server
medium (str): 3pid medium - e.g. email, msisdn
address (str): address within that medium (e.g. "wotan@matrix.org")
hs: server
medium: 3pid medium - e.g. email, msisdn
address: address within that medium (e.g. "wotan@matrix.org")
msisdns need to first have been canonicalised
Returns:
bool: whether the 3PID medium/address is allowed to be added to this HS

View file

@ -19,7 +19,7 @@ import subprocess
logger = logging.getLogger(__name__)
def get_version_string(module):
def get_version_string(module) -> str:
"""Given a module calculate a git-aware version string for it.
If called on a module not in a git checkout will return `__verison__`.

View file

@ -11,38 +11,41 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Generic, List, TypeVar
T = TypeVar("T")
class _Entry:
class _Entry(Generic[T]):
__slots__ = ["end_key", "queue"]
def __init__(self, end_key):
self.end_key = end_key
self.queue = []
def __init__(self, end_key: int) -> None:
self.end_key: int = end_key
self.queue: List[T] = []
class WheelTimer:
class WheelTimer(Generic[T]):
"""Stores arbitrary objects that will be returned after their timers have
expired.
"""
def __init__(self, bucket_size=5000):
def __init__(self, bucket_size: int = 5000) -> None:
"""
Args:
bucket_size (int): Size of buckets in ms. Corresponds roughly to the
bucket_size: Size of buckets in ms. Corresponds roughly to the
accuracy of the timer.
"""
self.bucket_size = bucket_size
self.entries = []
self.current_tick = 0
self.bucket_size: int = bucket_size
self.entries: List[_Entry[T]] = []
self.current_tick: int = 0
def insert(self, now, obj, then):
def insert(self, now: int, obj: T, then: int) -> None:
"""Inserts object into timer.
Args:
now (int): Current time in msec
obj (object): Object to be inserted
then (int): When to return the object strictly after.
now: Current time in msec
obj: Object to be inserted
then: When to return the object strictly after.
"""
then_key = int(then / self.bucket_size) + 1
@ -70,7 +73,7 @@ class WheelTimer:
self.entries[-1].queue.append(obj)
def fetch(self, now):
def fetch(self, now: int) -> List[T]:
"""Fetch any objects that have timed out
Args:
@ -87,5 +90,5 @@ class WheelTimer:
return ret
def __len__(self):
def __len__(self) -> int:
return sum(len(entry.queue) for entry in self.entries)

View file

@ -734,9 +734,9 @@ class TestTransportLayerServer(JsonResource):
FederationRateLimitConfig(
window_size=1,
sleep_limit=1,
sleep_msec=1,
sleep_delay=1,
reject_limit=1000,
concurrent_requests=1000,
concurrent=1000,
),
)