Add missing types to tests.util. (#14597)

Removes files under tests.util from the ignored by list, then
fully types all tests/util/*.py files.
This commit is contained in:
Patrick Cloke 2022-12-02 12:58:56 -05:00 committed by GitHub
parent fac8a38525
commit acea4d7a2f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 361 additions and 276 deletions

1
changelog.d/14597.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints.

View file

@ -59,16 +59,6 @@ exclude = (?x)
|tests/server_notices/test_resource_limits_server_notices.py |tests/server_notices/test_resource_limits_server_notices.py
|tests/test_state.py |tests/test_state.py
|tests/test_terms_auth.py |tests/test_terms_auth.py
|tests/util/test_async_helpers.py
|tests/util/test_batching_queue.py
|tests/util/test_dict_cache.py
|tests/util/test_expiring_cache.py
|tests/util/test_file_consumer.py
|tests/util/test_linearizer.py
|tests/util/test_logcontext.py
|tests/util/test_lrucache.py
|tests/util/test_rwlock.py
|tests/util/test_wheel_timer.py
)$ )$
[mypy-synapse.federation.transport.client] [mypy-synapse.federation.transport.client]
@ -137,6 +127,9 @@ disallow_untyped_defs = True
[mypy-tests.util.caches.test_descriptors] [mypy-tests.util.caches.test_descriptors]
disallow_untyped_defs = False disallow_untyped_defs = False
[mypy-tests.util.*]
disallow_untyped_defs = True
[mypy-tests.utils] [mypy-tests.utils]
disallow_untyped_defs = True disallow_untyped_defs = True

View file

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import traceback import traceback
from typing import Generator, List, NoReturn, Optional
from parameterized import parameterized_class from parameterized import parameterized_class
@ -41,8 +42,8 @@ from tests.unittest import TestCase
class ObservableDeferredTest(TestCase): class ObservableDeferredTest(TestCase):
def test_succeed(self): def test_succeed(self) -> None:
origin_d = Deferred() origin_d: "Deferred[int]" = Deferred()
observable = ObservableDeferred(origin_d) observable = ObservableDeferred(origin_d)
observer1 = observable.observe() observer1 = observable.observe()
@ -52,16 +53,18 @@ class ObservableDeferredTest(TestCase):
self.assertFalse(observer2.called) self.assertFalse(observer2.called)
# check the first observer is called first # check the first observer is called first
def check_called_first(res): def check_called_first(res: int) -> int:
self.assertFalse(observer2.called) self.assertFalse(observer2.called)
return res return res
observer1.addBoth(check_called_first) observer1.addBoth(check_called_first)
# store the results # store the results
results = [None, None] results: List[Optional[ObservableDeferred[int]]] = [None, None]
def check_val(res, idx): def check_val(
res: ObservableDeferred[int], idx: int
) -> ObservableDeferred[int]:
results[idx] = res results[idx] = res
return res return res
@ -72,8 +75,8 @@ class ObservableDeferredTest(TestCase):
self.assertEqual(results[0], 123, "observer 1 callback result") self.assertEqual(results[0], 123, "observer 1 callback result")
self.assertEqual(results[1], 123, "observer 2 callback result") self.assertEqual(results[1], 123, "observer 2 callback result")
def test_failure(self): def test_failure(self) -> None:
origin_d = Deferred() origin_d: Deferred = Deferred()
observable = ObservableDeferred(origin_d, consumeErrors=True) observable = ObservableDeferred(origin_d, consumeErrors=True)
observer1 = observable.observe() observer1 = observable.observe()
@ -83,16 +86,16 @@ class ObservableDeferredTest(TestCase):
self.assertFalse(observer2.called) self.assertFalse(observer2.called)
# check the first observer is called first # check the first observer is called first
def check_called_first(res): def check_called_first(res: int) -> int:
self.assertFalse(observer2.called) self.assertFalse(observer2.called)
return res return res
observer1.addBoth(check_called_first) observer1.addBoth(check_called_first)
# store the results # store the results
results = [None, None] results: List[Optional[ObservableDeferred[str]]] = [None, None]
def check_val(res, idx): def check_val(res: ObservableDeferred[str], idx: int) -> None:
results[idx] = res results[idx] = res
return None return None
@ -103,10 +106,12 @@ class ObservableDeferredTest(TestCase):
raise Exception("gah!") raise Exception("gah!")
except Exception as e: except Exception as e:
origin_d.errback(e) origin_d.errback(e)
assert results[0] is not None
self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result") self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
assert results[1] is not None
self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result") self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
def test_cancellation(self): def test_cancellation(self) -> None:
"""Test that cancelling an observer does not affect other observers.""" """Test that cancelling an observer does not affect other observers."""
origin_d: "Deferred[int]" = Deferred() origin_d: "Deferred[int]" = Deferred()
observable = ObservableDeferred(origin_d, consumeErrors=True) observable = ObservableDeferred(origin_d, consumeErrors=True)
@ -136,37 +141,38 @@ class ObservableDeferredTest(TestCase):
class TimeoutDeferredTest(TestCase): class TimeoutDeferredTest(TestCase):
def setUp(self): def setUp(self) -> None:
self.clock = Clock() self.clock = Clock()
def test_times_out(self): def test_times_out(self) -> None:
"""Basic test case that checks that the original deferred is cancelled and that """Basic test case that checks that the original deferred is cancelled and that
the timing-out deferred is errbacked the timing-out deferred is errbacked
""" """
cancelled = [False] cancelled = False
def canceller(_d): def canceller(_d: Deferred) -> None:
cancelled[0] = True nonlocal cancelled
cancelled = True
non_completing_d = Deferred(canceller) non_completing_d: Deferred = Deferred(canceller)
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
self.assertNoResult(timing_out_d) self.assertNoResult(timing_out_d)
self.assertFalse(cancelled[0], "deferred was cancelled prematurely") self.assertFalse(cancelled, "deferred was cancelled prematurely")
self.clock.pump((1.0,)) self.clock.pump((1.0,))
self.assertTrue(cancelled[0], "deferred was not cancelled by timeout") self.assertTrue(cancelled, "deferred was not cancelled by timeout")
self.failureResultOf(timing_out_d, defer.TimeoutError) self.failureResultOf(timing_out_d, defer.TimeoutError)
def test_times_out_when_canceller_throws(self): def test_times_out_when_canceller_throws(self) -> None:
"""Test that we have successfully worked around """Test that we have successfully worked around
https://twistedmatrix.com/trac/ticket/9534""" https://twistedmatrix.com/trac/ticket/9534"""
def canceller(_d): def canceller(_d: Deferred) -> None:
raise Exception("can't cancel this deferred") raise Exception("can't cancel this deferred")
non_completing_d = Deferred(canceller) non_completing_d: Deferred = Deferred(canceller)
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
self.assertNoResult(timing_out_d) self.assertNoResult(timing_out_d)
@ -175,22 +181,24 @@ class TimeoutDeferredTest(TestCase):
self.failureResultOf(timing_out_d, defer.TimeoutError) self.failureResultOf(timing_out_d, defer.TimeoutError)
def test_logcontext_is_preserved_on_cancellation(self): def test_logcontext_is_preserved_on_cancellation(self) -> None:
blocking_was_cancelled = [False] blocking_was_cancelled = False
@defer.inlineCallbacks @defer.inlineCallbacks
def blocking(): def blocking() -> Generator["Deferred[object]", object, None]:
non_completing_d = Deferred() nonlocal blocking_was_cancelled
non_completing_d: Deferred = Deferred()
with PreserveLoggingContext(): with PreserveLoggingContext():
try: try:
yield non_completing_d yield non_completing_d
except CancelledError: except CancelledError:
blocking_was_cancelled[0] = True blocking_was_cancelled = True
raise raise
with LoggingContext("one") as context_one: with LoggingContext("one") as context_one:
# the errbacks should be run in the test logcontext # the errbacks should be run in the test logcontext
def errback(res, deferred_name): def errback(res: Failure, deferred_name: str) -> Failure:
self.assertIs( self.assertIs(
current_context(), current_context(),
context_one, context_one,
@ -209,7 +217,7 @@ class TimeoutDeferredTest(TestCase):
self.clock.pump((1.0,)) self.clock.pump((1.0,))
self.assertTrue( self.assertTrue(
blocking_was_cancelled[0], "non-completing deferred was not cancelled" blocking_was_cancelled, "non-completing deferred was not cancelled"
) )
self.failureResultOf(timing_out_d, defer.TimeoutError) self.failureResultOf(timing_out_d, defer.TimeoutError)
self.assertIs(current_context(), context_one) self.assertIs(current_context(), context_one)
@ -220,13 +228,13 @@ class _TestException(Exception):
class ConcurrentlyExecuteTest(TestCase): class ConcurrentlyExecuteTest(TestCase):
def test_limits_runners(self): def test_limits_runners(self) -> None:
"""If we have more tasks than runners, we should get the limit of runners""" """If we have more tasks than runners, we should get the limit of runners"""
started = 0 started = 0
waiters = [] waiters = []
processed = [] processed = []
async def callback(v): async def callback(v: int) -> None:
# when we first enter, bump the start count # when we first enter, bump the start count
nonlocal started nonlocal started
started += 1 started += 1
@ -235,7 +243,7 @@ class ConcurrentlyExecuteTest(TestCase):
processed.append(v) processed.append(v)
# wait for the goahead before returning # wait for the goahead before returning
d2 = Deferred() d2: "Deferred[int]" = Deferred()
waiters.append(d2) waiters.append(d2)
await d2 await d2
@ -265,16 +273,16 @@ class ConcurrentlyExecuteTest(TestCase):
self.assertCountEqual(processed, [1, 2, 3, 4, 5]) self.assertCountEqual(processed, [1, 2, 3, 4, 5])
self.successResultOf(d2) self.successResultOf(d2)
def test_preserves_stacktraces(self): def test_preserves_stacktraces(self) -> None:
"""Test that the stacktrace from an exception thrown in the callback is preserved""" """Test that the stacktrace from an exception thrown in the callback is preserved"""
d1 = Deferred() d1: "Deferred[int]" = Deferred()
async def callback(v): async def callback(v: int) -> None:
# alas, this doesn't work at all without an await here # alas, this doesn't work at all without an await here
await d1 await d1
raise _TestException("bah") raise _TestException("bah")
async def caller(): async def caller() -> None:
try: try:
await concurrently_execute(callback, [1], 2) await concurrently_execute(callback, [1], 2)
except _TestException as e: except _TestException as e:
@ -290,17 +298,17 @@ class ConcurrentlyExecuteTest(TestCase):
d1.callback(0) d1.callback(0)
self.successResultOf(d2) self.successResultOf(d2)
def test_preserves_stacktraces_on_preformed_failure(self): def test_preserves_stacktraces_on_preformed_failure(self) -> None:
"""Test that the stacktrace on a Failure returned by the callback is preserved""" """Test that the stacktrace on a Failure returned by the callback is preserved"""
d1 = Deferred() d1: "Deferred[int]" = Deferred()
f = Failure(_TestException("bah")) f = Failure(_TestException("bah"))
async def callback(v): async def callback(v: int) -> None:
# alas, this doesn't work at all without an await here # alas, this doesn't work at all without an await here
await d1 await d1
await defer.fail(f) await defer.fail(f)
async def caller(): async def caller() -> None:
try: try:
await concurrently_execute(callback, [1], 2) await concurrently_execute(callback, [1], 2)
except _TestException as e: except _TestException as e:
@ -336,7 +344,7 @@ class CancellationWrapperTests(TestCase):
else: else:
raise ValueError(f"Unsupported wrapper type: {self.wrapper}") raise ValueError(f"Unsupported wrapper type: {self.wrapper}")
def test_succeed(self): def test_succeed(self) -> None:
"""Test that the new `Deferred` receives the result.""" """Test that the new `Deferred` receives the result."""
deferred: "Deferred[str]" = Deferred() deferred: "Deferred[str]" = Deferred()
wrapper_deferred = self.wrap_deferred(deferred) wrapper_deferred = self.wrap_deferred(deferred)
@ -346,7 +354,7 @@ class CancellationWrapperTests(TestCase):
self.assertTrue(wrapper_deferred.called) self.assertTrue(wrapper_deferred.called)
self.assertEqual("success", self.successResultOf(wrapper_deferred)) self.assertEqual("success", self.successResultOf(wrapper_deferred))
def test_failure(self): def test_failure(self) -> None:
"""Test that the new `Deferred` receives the `Failure`.""" """Test that the new `Deferred` receives the `Failure`."""
deferred: "Deferred[str]" = Deferred() deferred: "Deferred[str]" = Deferred()
wrapper_deferred = self.wrap_deferred(deferred) wrapper_deferred = self.wrap_deferred(deferred)
@ -361,7 +369,7 @@ class CancellationWrapperTests(TestCase):
class StopCancellationTests(TestCase): class StopCancellationTests(TestCase):
"""Tests for the `stop_cancellation` function.""" """Tests for the `stop_cancellation` function."""
def test_cancellation(self): def test_cancellation(self) -> None:
"""Test that cancellation of the new `Deferred` leaves the original running.""" """Test that cancellation of the new `Deferred` leaves the original running."""
deferred: "Deferred[str]" = Deferred() deferred: "Deferred[str]" = Deferred()
wrapper_deferred = stop_cancellation(deferred) wrapper_deferred = stop_cancellation(deferred)
@ -384,7 +392,7 @@ class StopCancellationTests(TestCase):
class DelayCancellationTests(TestCase): class DelayCancellationTests(TestCase):
"""Tests for the `delay_cancellation` function.""" """Tests for the `delay_cancellation` function."""
def test_deferred_cancellation(self): def test_deferred_cancellation(self) -> None:
"""Test that cancellation of the new `Deferred` waits for the original.""" """Test that cancellation of the new `Deferred` waits for the original."""
deferred: "Deferred[str]" = Deferred() deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred) wrapper_deferred = delay_cancellation(deferred)
@ -405,12 +413,12 @@ class DelayCancellationTests(TestCase):
# Now that the original `Deferred` has failed, we should get a `CancelledError`. # Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError) self.failureResultOf(wrapper_deferred, CancelledError)
def test_coroutine_cancellation(self): def test_coroutine_cancellation(self) -> None:
"""Test that cancellation of the new `Deferred` waits for the original.""" """Test that cancellation of the new `Deferred` waits for the original."""
blocking_deferred: "Deferred[None]" = Deferred() blocking_deferred: "Deferred[None]" = Deferred()
completion_deferred: "Deferred[None]" = Deferred() completion_deferred: "Deferred[None]" = Deferred()
async def task(): async def task() -> NoReturn:
await blocking_deferred await blocking_deferred
completion_deferred.callback(None) completion_deferred.callback(None)
# Raise an exception. Twisted should consume it, otherwise unwanted # Raise an exception. Twisted should consume it, otherwise unwanted
@ -434,7 +442,7 @@ class DelayCancellationTests(TestCase):
# Now that the original coroutine has failed, we should get a `CancelledError`. # Now that the original coroutine has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError) self.failureResultOf(wrapper_deferred, CancelledError)
def test_suppresses_second_cancellation(self): def test_suppresses_second_cancellation(self) -> None:
"""Test that a second cancellation is suppressed. """Test that a second cancellation is suppressed.
Identical to `test_cancellation` except the new `Deferred` is cancelled twice. Identical to `test_cancellation` except the new `Deferred` is cancelled twice.
@ -459,7 +467,7 @@ class DelayCancellationTests(TestCase):
# Now that the original `Deferred` has failed, we should get a `CancelledError`. # Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError) self.failureResultOf(wrapper_deferred, CancelledError)
def test_propagates_cancelled_error(self): def test_propagates_cancelled_error(self) -> None:
"""Test that a `CancelledError` from the original `Deferred` gets propagated.""" """Test that a `CancelledError` from the original `Deferred` gets propagated."""
deferred: "Deferred[str]" = Deferred() deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred) wrapper_deferred = delay_cancellation(deferred)
@ -472,14 +480,14 @@ class DelayCancellationTests(TestCase):
self.assertTrue(wrapper_deferred.called) self.assertTrue(wrapper_deferred.called)
self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value) self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value)
def test_preserves_logcontext(self): def test_preserves_logcontext(self) -> None:
"""Test that logging contexts are preserved.""" """Test that logging contexts are preserved."""
blocking_d: "Deferred[None]" = Deferred() blocking_d: "Deferred[None]" = Deferred()
async def inner(): async def inner() -> None:
await make_deferred_yieldable(blocking_d) await make_deferred_yieldable(blocking_d)
async def outer(): async def outer() -> None:
with LoggingContext("c") as c: with LoggingContext("c") as c:
try: try:
await delay_cancellation(inner()) await delay_cancellation(inner())
@ -503,7 +511,7 @@ class DelayCancellationTests(TestCase):
class AwakenableSleeperTests(TestCase): class AwakenableSleeperTests(TestCase):
"Tests AwakenableSleeper" "Tests AwakenableSleeper"
def test_sleep(self): def test_sleep(self) -> None:
reactor, _ = get_clock() reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor) sleeper = AwakenableSleeper(reactor)
@ -518,7 +526,7 @@ class AwakenableSleeperTests(TestCase):
reactor.advance(0.6) reactor.advance(0.6)
self.assertTrue(d.called) self.assertTrue(d.called)
def test_explicit_wake(self): def test_explicit_wake(self) -> None:
reactor, _ = get_clock() reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor) sleeper = AwakenableSleeper(reactor)
@ -535,7 +543,7 @@ class AwakenableSleeperTests(TestCase):
reactor.advance(0.6) reactor.advance(0.6)
def test_multiple_sleepers_timeout(self): def test_multiple_sleepers_timeout(self) -> None:
reactor, _ = get_clock() reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor) sleeper = AwakenableSleeper(reactor)
@ -555,7 +563,7 @@ class AwakenableSleeperTests(TestCase):
reactor.advance(0.6) reactor.advance(0.6)
self.assertTrue(d2.called) self.assertTrue(d2.called)
def test_multiple_sleepers_wake(self): def test_multiple_sleepers_wake(self) -> None:
reactor, _ = get_clock() reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor) sleeper = AwakenableSleeper(reactor)

View file

@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Tuple
from prometheus_client import Gauge
from twisted.internet import defer from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
@ -26,7 +30,7 @@ from tests.unittest import TestCase
class BatchingQueueTestCase(TestCase): class BatchingQueueTestCase(TestCase):
def setUp(self): def setUp(self) -> None:
self.clock, hs_clock = get_clock() self.clock, hs_clock = get_clock()
# We ensure that we remove any existing metrics for "test_queue". # We ensure that we remove any existing metrics for "test_queue".
@ -37,25 +41,27 @@ class BatchingQueueTestCase(TestCase):
except KeyError: except KeyError:
pass pass
self._pending_calls = [] self._pending_calls: List[Tuple[List[str], defer.Deferred]] = []
self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue) self.queue: BatchingQueue[str, str] = BatchingQueue(
"test_queue", hs_clock, self._process_queue
)
async def _process_queue(self, values): async def _process_queue(self, values: List[str]) -> str:
d = defer.Deferred() d: "defer.Deferred[str]" = defer.Deferred()
self._pending_calls.append((values, d)) self._pending_calls.append((values, d))
return await make_deferred_yieldable(d) return await make_deferred_yieldable(d)
def _get_sample_with_name(self, metric, name) -> int: def _get_sample_with_name(self, metric: Gauge, name: str) -> float:
"""For a prometheus metric get the value of the sample that has a """For a prometheus metric get the value of the sample that has a
matching "name" label. matching "name" label.
""" """
for sample in metric.collect()[0].samples: for sample in next(iter(metric.collect())).samples:
if sample.labels.get("name") == name: if sample.labels.get("name") == name:
return sample.value return sample.value
self.fail("Found no matching sample") self.fail("Found no matching sample")
def _assert_metrics(self, queued, keys, in_flight): def _assert_metrics(self, queued: int, keys: int, in_flight: int) -> None:
"""Assert that the metrics are correct""" """Assert that the metrics are correct"""
sample = self._get_sample_with_name(number_queued, self.queue._name) sample = self._get_sample_with_name(number_queued, self.queue._name)
@ -75,7 +81,7 @@ class BatchingQueueTestCase(TestCase):
"number_in_flight", "number_in_flight",
) )
def test_simple(self): def test_simple(self) -> None:
"""Tests the basic case of calling `add_to_queue` once and having """Tests the basic case of calling `add_to_queue` once and having
`_process_queue` return. `_process_queue` return.
""" """
@ -106,7 +112,7 @@ class BatchingQueueTestCase(TestCase):
self._assert_metrics(queued=0, keys=0, in_flight=0) self._assert_metrics(queued=0, keys=0, in_flight=0)
def test_batching(self): def test_batching(self) -> None:
"""Test that multiple calls at the same time get batched up into one """Test that multiple calls at the same time get batched up into one
call to `_process_queue`. call to `_process_queue`.
""" """
@ -134,7 +140,7 @@ class BatchingQueueTestCase(TestCase):
self.assertEqual(self.successResultOf(queue_d2), "bar") self.assertEqual(self.successResultOf(queue_d2), "bar")
self._assert_metrics(queued=0, keys=0, in_flight=0) self._assert_metrics(queued=0, keys=0, in_flight=0)
def test_queuing(self): def test_queuing(self) -> None:
"""Test that we queue up requests while a `_process_queue` is being """Test that we queue up requests while a `_process_queue` is being
called. called.
""" """
@ -184,7 +190,7 @@ class BatchingQueueTestCase(TestCase):
self.assertEqual(self.successResultOf(queue_d3), "bar2") self.assertEqual(self.successResultOf(queue_d3), "bar2")
self._assert_metrics(queued=0, keys=0, in_flight=0) self._assert_metrics(queued=0, keys=0, in_flight=0)
def test_different_keys(self): def test_different_keys(self) -> None:
"""Test that calls to different keys get processed in parallel.""" """Test that calls to different keys get processed in parallel."""
self.assertFalse(self._pending_calls) self.assertFalse(self._pending_calls)

View file

@ -1,5 +1,20 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager from contextlib import contextmanager
from typing import Generator, Optional from os import PathLike
from typing import Generator, Optional, Union
from unittest.mock import patch from unittest.mock import patch
from synapse.util.check_dependencies import ( from synapse.util.check_dependencies import (
@ -12,17 +27,17 @@ from tests.unittest import TestCase
class DummyDistribution(metadata.Distribution): class DummyDistribution(metadata.Distribution):
def __init__(self, version: object): def __init__(self, version: str):
self._version = version self._version = version
@property @property
def version(self): def version(self) -> str:
return self._version return self._version
def locate_file(self, path): def locate_file(self, path: Union[str, PathLike]) -> PathLike:
raise NotImplementedError() raise NotImplementedError()
def read_text(self, filename): def read_text(self, filename: str) -> None:
raise NotImplementedError() raise NotImplementedError()
@ -30,7 +45,7 @@ old = DummyDistribution("0.1.2")
old_release_candidate = DummyDistribution("0.1.2rc3") old_release_candidate = DummyDistribution("0.1.2rc3")
new = DummyDistribution("1.2.3") new = DummyDistribution("1.2.3")
new_release_candidate = DummyDistribution("1.2.3rc4") new_release_candidate = DummyDistribution("1.2.3rc4")
distribution_with_no_version = DummyDistribution(None) distribution_with_no_version = DummyDistribution(None) # type: ignore[arg-type]
# could probably use stdlib TestCase --- no need for twisted here # could probably use stdlib TestCase --- no need for twisted here
@ -45,7 +60,7 @@ class TestDependencyChecker(TestCase):
If `distribution = None`, we pretend that the package is not installed. If `distribution = None`, we pretend that the package is not installed.
""" """
def mock_distribution(name: str): def mock_distribution(name: str) -> DummyDistribution:
if distribution is None: if distribution is None:
raise metadata.PackageNotFoundError raise metadata.PackageNotFoundError
else: else:

View file

@ -19,10 +19,12 @@ from tests import unittest
class DictCacheTestCase(unittest.TestCase): class DictCacheTestCase(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
self.cache = DictionaryCache("foobar", max_entries=10) self.cache: DictionaryCache[str, str, str] = DictionaryCache(
"foobar", max_entries=10
)
def test_simple_cache_hit_full(self): def test_simple_cache_hit_full(self) -> None:
key = "test_simple_cache_hit_full" key = "test_simple_cache_hit_full"
v = self.cache.get(key) v = self.cache.get(key)
@ -37,7 +39,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key) c = self.cache.get(key)
self.assertEqual(test_value, c.value) self.assertEqual(test_value, c.value)
def test_simple_cache_hit_partial(self): def test_simple_cache_hit_partial(self) -> None:
key = "test_simple_cache_hit_partial" key = "test_simple_cache_hit_partial"
seq = self.cache.sequence seq = self.cache.sequence
@ -47,7 +49,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key, ["test"]) c = self.cache.get(key, ["test"])
self.assertEqual(test_value, c.value) self.assertEqual(test_value, c.value)
def test_simple_cache_miss_partial(self): def test_simple_cache_miss_partial(self) -> None:
key = "test_simple_cache_miss_partial" key = "test_simple_cache_miss_partial"
seq = self.cache.sequence seq = self.cache.sequence
@ -57,7 +59,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key, ["test2"]) c = self.cache.get(key, ["test2"])
self.assertEqual({}, c.value) self.assertEqual({}, c.value)
def test_simple_cache_hit_miss_partial(self): def test_simple_cache_hit_miss_partial(self) -> None:
key = "test_simple_cache_hit_miss_partial" key = "test_simple_cache_hit_miss_partial"
seq = self.cache.sequence seq = self.cache.sequence
@ -71,7 +73,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key, ["test2"]) c = self.cache.get(key, ["test2"])
self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value) self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
def test_multi_insert(self): def test_multi_insert(self) -> None:
key = "test_simple_cache_hit_miss_partial" key = "test_simple_cache_hit_miss_partial"
seq = self.cache.sequence seq = self.cache.sequence
@ -92,7 +94,7 @@ class DictCacheTestCase(unittest.TestCase):
) )
self.assertEqual(c.full, False) self.assertEqual(c.full, False)
def test_invalidation(self): def test_invalidation(self) -> None:
"""Test that the partial dict and full dicts get invalidated """Test that the partial dict and full dicts get invalidated
separately. separately.
""" """
@ -106,7 +108,7 @@ class DictCacheTestCase(unittest.TestCase):
# entry for "a" warm. # entry for "a" warm.
for i in range(20): for i in range(20):
self.cache.get(key, ["a"]) self.cache.get(key, ["a"])
self.cache.update(seq, f"key{i}", {1: 2}) self.cache.update(seq, f"key{i}", {"1": "2"})
# We should have evicted the full dict... # We should have evicted the full dict...
r = self.cache.get(key) r = self.cache.get(key)

View file

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, cast
from synapse.util import Clock
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from tests.utils import MockClock from tests.utils import MockClock
@ -21,17 +23,21 @@ from .. import unittest
class ExpiringCacheTestCase(unittest.HomeserverTestCase): class ExpiringCacheTestCase(unittest.HomeserverTestCase):
def test_get_set(self): def test_get_set(self) -> None:
clock = MockClock() clock = MockClock()
cache = ExpiringCache("test", clock, max_len=1) cache: ExpiringCache[str, str] = ExpiringCache(
"test", cast(Clock, clock), max_len=1
)
cache["key"] = "value" cache["key"] = "value"
self.assertEqual(cache.get("key"), "value") self.assertEqual(cache.get("key"), "value")
self.assertEqual(cache["key"], "value") self.assertEqual(cache["key"], "value")
def test_eviction(self): def test_eviction(self) -> None:
clock = MockClock() clock = MockClock()
cache = ExpiringCache("test", clock, max_len=2) cache: ExpiringCache[str, str] = ExpiringCache(
"test", cast(Clock, clock), max_len=2
)
cache["key"] = "value" cache["key"] = "value"
cache["key2"] = "value2" cache["key2"] = "value2"
@ -43,9 +49,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key2"), "value2") self.assertEqual(cache.get("key2"), "value2")
self.assertEqual(cache.get("key3"), "value3") self.assertEqual(cache.get("key3"), "value3")
def test_iterable_eviction(self): def test_iterable_eviction(self) -> None:
clock = MockClock() clock = MockClock()
cache = ExpiringCache("test", clock, max_len=5, iterable=True) cache: ExpiringCache[str, List[int]] = ExpiringCache(
"test", cast(Clock, clock), max_len=5, iterable=True
)
cache["key"] = [1] cache["key"] = [1]
cache["key2"] = [2, 3] cache["key2"] = [2, 3]
@ -61,9 +69,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key3"), [4, 5]) self.assertEqual(cache.get("key3"), [4, 5])
self.assertEqual(cache.get("key4"), [6, 7]) self.assertEqual(cache.get("key4"), [6, 7])
def test_time_eviction(self): def test_time_eviction(self) -> None:
clock = MockClock() clock = MockClock()
cache = ExpiringCache("test", clock, expiry_ms=1000) cache: ExpiringCache[str, int] = ExpiringCache(
"test", cast(Clock, clock), expiry_ms=1000
)
cache["key"] = 1 cache["key"] = 1
clock.advance_time(0.5) clock.advance_time(0.5)

View file

@ -12,22 +12,28 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import threading import threading
from io import StringIO from io import BytesIO
from typing import BinaryIO, Generator, Optional, cast
from unittest.mock import NonCallableMock from unittest.mock import NonCallableMock
from twisted.internet import defer, reactor from zope.interface import implementer
from twisted.internet import defer, reactor as _reactor
from twisted.internet.interfaces import IPullProducer
from synapse.types import ISynapseReactor
from synapse.util.file_consumer import BackgroundFileConsumer from synapse.util.file_consumer import BackgroundFileConsumer
from tests import unittest from tests import unittest
reactor = cast(ISynapseReactor, _reactor)
class FileConsumerTests(unittest.TestCase): class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_pull_consumer(self): def test_pull_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
string_file = StringIO() string_file = BytesIO()
consumer = BackgroundFileConsumer(string_file, reactor=reactor) consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try: try:
@ -35,55 +41,57 @@ class FileConsumerTests(unittest.TestCase):
yield producer.register_with_consumer(consumer) yield producer.register_with_consumer(consumer)
yield producer.write_and_wait("Foo") yield producer.write_and_wait(b"Foo")
self.assertEqual(string_file.getvalue(), "Foo") self.assertEqual(string_file.getvalue(), b"Foo")
yield producer.write_and_wait("Bar") yield producer.write_and_wait(b"Bar")
self.assertEqual(string_file.getvalue(), "FooBar") self.assertEqual(string_file.getvalue(), b"FooBar")
finally: finally:
consumer.unregisterProducer() consumer.unregisterProducer()
yield consumer.wait() yield consumer.wait() # type: ignore[misc]
self.assertTrue(string_file.closed) self.assertTrue(string_file.closed)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_push_consumer(self): def test_push_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
string_file = BlockingStringWrite() string_file = BlockingBytesWrite()
consumer = BackgroundFileConsumer(string_file, reactor=reactor) consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
try: try:
producer = NonCallableMock(spec_set=[]) producer = NonCallableMock(spec_set=[])
consumer.registerProducer(producer, True) consumer.registerProducer(producer, True)
consumer.write("Foo") consumer.write(b"Foo")
yield string_file.wait_for_n_writes(1) yield string_file.wait_for_n_writes(1) # type: ignore[misc]
self.assertEqual(string_file.buffer, "Foo") self.assertEqual(string_file.buffer, b"Foo")
consumer.write("Bar") consumer.write(b"Bar")
yield string_file.wait_for_n_writes(2) yield string_file.wait_for_n_writes(2) # type: ignore[misc]
self.assertEqual(string_file.buffer, "FooBar") self.assertEqual(string_file.buffer, b"FooBar")
finally: finally:
consumer.unregisterProducer() consumer.unregisterProducer()
yield consumer.wait() yield consumer.wait() # type: ignore[misc]
self.assertTrue(string_file.closed) self.assertTrue(string_file.closed)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_push_producer_feedback(self): def test_push_producer_feedback(
string_file = BlockingStringWrite() self,
consumer = BackgroundFileConsumer(string_file, reactor=reactor) ) -> Generator["defer.Deferred[object]", object, None]:
string_file = BlockingBytesWrite()
consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
try: try:
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"]) producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
resume_deferred = defer.Deferred() resume_deferred: defer.Deferred = defer.Deferred()
producer.resumeProducing.side_effect = lambda: resume_deferred.callback( producer.resumeProducing.side_effect = lambda: resume_deferred.callback(
None None
) )
@ -93,65 +101,72 @@ class FileConsumerTests(unittest.TestCase):
number_writes = 0 number_writes = 0
with string_file.write_lock: with string_file.write_lock:
for _ in range(consumer._PAUSE_ON_QUEUE_SIZE): for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
consumer.write("Foo") consumer.write(b"Foo")
number_writes += 1 number_writes += 1
producer.pauseProducing.assert_called_once() producer.pauseProducing.assert_called_once()
yield string_file.wait_for_n_writes(number_writes) yield string_file.wait_for_n_writes(number_writes) # type: ignore[misc]
yield resume_deferred yield resume_deferred
producer.resumeProducing.assert_called_once() producer.resumeProducing.assert_called_once()
finally: finally:
consumer.unregisterProducer() consumer.unregisterProducer()
yield consumer.wait() yield consumer.wait() # type: ignore[misc]
self.assertTrue(string_file.closed) self.assertTrue(string_file.closed)
@implementer(IPullProducer)
class DummyPullProducer: class DummyPullProducer:
def __init__(self): def __init__(self) -> None:
self.consumer = None self.consumer: Optional[BackgroundFileConsumer] = None
self.deferred = defer.Deferred() self.deferred: "defer.Deferred[object]" = defer.Deferred()
def resumeProducing(self): def resumeProducing(self) -> None:
d = self.deferred d = self.deferred
self.deferred = defer.Deferred() self.deferred = defer.Deferred()
d.callback(None) d.callback(None)
def write_and_wait(self, bytes): def stopProducing(self) -> None:
raise RuntimeError("Unexpected call")
def write_and_wait(self, write_bytes: bytes) -> "defer.Deferred[object]":
assert self.consumer is not None
d = self.deferred d = self.deferred
self.consumer.write(bytes) self.consumer.write(write_bytes)
return d return d
def register_with_consumer(self, consumer): def register_with_consumer(
self, consumer: BackgroundFileConsumer
) -> "defer.Deferred[object]":
d = self.deferred d = self.deferred
self.consumer = consumer self.consumer = consumer
self.consumer.registerProducer(self, False) self.consumer.registerProducer(self, False)
return d return d
class BlockingStringWrite: class BlockingBytesWrite:
def __init__(self): def __init__(self) -> None:
self.buffer = "" self.buffer = b""
self.closed = False self.closed = False
self.write_lock = threading.Lock() self.write_lock = threading.Lock()
self._notify_write_deferred = None self._notify_write_deferred: Optional[defer.Deferred] = None
self._number_of_writes = 0 self._number_of_writes = 0
def write(self, bytes): def write(self, write_bytes: bytes) -> None:
with self.write_lock: with self.write_lock:
self.buffer += bytes self.buffer += write_bytes
self._number_of_writes += 1 self._number_of_writes += 1
reactor.callFromThread(self._notify_write) reactor.callFromThread(self._notify_write)
def close(self): def close(self) -> None:
self.closed = True self.closed = True
def _notify_write(self): def _notify_write(self) -> None:
"Called by write to indicate a write happened" "Called by write to indicate a write happened"
with self.write_lock: with self.write_lock:
if not self._notify_write_deferred: if not self._notify_write_deferred:
@ -161,7 +176,9 @@ class BlockingStringWrite:
d.callback(None) d.callback(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_n_writes(self, n): def wait_for_n_writes(
self, n: int
) -> Generator["defer.Deferred[object]", object, None]:
"Wait for n writes to have happened" "Wait for n writes to have happened"
while True: while True:
with self.write_lock: with self.write_lock:

View file

@ -19,7 +19,7 @@ from tests.unittest import TestCase
class ChunkSeqTests(TestCase): class ChunkSeqTests(TestCase):
def test_short_seq(self): def test_short_seq(self) -> None:
parts = chunk_seq("123", 8) parts = chunk_seq("123", 8)
self.assertEqual( self.assertEqual(
@ -27,7 +27,7 @@ class ChunkSeqTests(TestCase):
["123"], ["123"],
) )
def test_long_seq(self): def test_long_seq(self) -> None:
parts = chunk_seq("abcdefghijklmnop", 8) parts = chunk_seq("abcdefghijklmnop", 8)
self.assertEqual( self.assertEqual(
@ -35,7 +35,7 @@ class ChunkSeqTests(TestCase):
["abcdefgh", "ijklmnop"], ["abcdefgh", "ijklmnop"],
) )
def test_uneven_parts(self): def test_uneven_parts(self) -> None:
parts = chunk_seq("abcdefghijklmnop", 5) parts = chunk_seq("abcdefghijklmnop", 5)
self.assertEqual( self.assertEqual(
@ -43,7 +43,7 @@ class ChunkSeqTests(TestCase):
["abcde", "fghij", "klmno", "p"], ["abcde", "fghij", "klmno", "p"],
) )
def test_empty_input(self): def test_empty_input(self) -> None:
parts: Iterable[Sequence] = chunk_seq([], 5) parts: Iterable[Sequence] = chunk_seq([], 5)
self.assertEqual( self.assertEqual(
@ -53,13 +53,13 @@ class ChunkSeqTests(TestCase):
class SortTopologically(TestCase): class SortTopologically(TestCase):
def test_empty(self): def test_empty(self) -> None:
"Test that an empty graph works correctly" "Test that an empty graph works correctly"
graph: Dict[int, List[int]] = {} graph: Dict[int, List[int]] = {}
self.assertEqual(list(sorted_topologically([], graph)), []) self.assertEqual(list(sorted_topologically([], graph)), [])
def test_handle_empty_graph(self): def test_handle_empty_graph(self) -> None:
"Test that a graph where a node doesn't have an entry is treated as empty" "Test that a graph where a node doesn't have an entry is treated as empty"
graph: Dict[int, List[int]] = {} graph: Dict[int, List[int]] = {}
@ -67,7 +67,7 @@ class SortTopologically(TestCase):
# For disconnected nodes the output is simply sorted. # For disconnected nodes the output is simply sorted.
self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
def test_disconnected(self): def test_disconnected(self) -> None:
"Test that a graph with no edges work" "Test that a graph with no edges work"
graph: Dict[int, List[int]] = {1: [], 2: []} graph: Dict[int, List[int]] = {1: [], 2: []}
@ -75,20 +75,20 @@ class SortTopologically(TestCase):
# For disconnected nodes the output is simply sorted. # For disconnected nodes the output is simply sorted.
self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
def test_linear(self): def test_linear(self) -> None:
"Test that a simple `4 -> 3 -> 2 -> 1` graph works" "Test that a simple `4 -> 3 -> 2 -> 1` graph works"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]} graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
def test_subset(self): def test_subset(self) -> None:
"Test that only sorting a subset of the graph works" "Test that only sorting a subset of the graph works"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]} graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4]) self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
def test_fork(self): def test_fork(self) -> None:
"Test that a forked graph works" "Test that a forked graph works"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]} graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]}
@ -96,13 +96,13 @@ class SortTopologically(TestCase):
# always get the same one. # always get the same one.
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
def test_duplicates(self): def test_duplicates(self) -> None:
"Test that a graph with duplicate edges work" "Test that a graph with duplicate edges work"
graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
def test_multiple_paths(self): def test_multiple_paths(self) -> None:
"Test that a graph with multiple paths between two nodes work" "Test that a graph with multiple paths between two nodes work"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]}

View file

@ -1,5 +1,21 @@
# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Generator, cast
import twisted.python.failure import twisted.python.failure
from twisted.internet import defer, reactor from twisted.internet import defer, reactor as _reactor
from synapse.logging.context import ( from synapse.logging.context import (
SENTINEL_CONTEXT, SENTINEL_CONTEXT,
@ -10,25 +26,30 @@ from synapse.logging.context import (
nested_logging_context, nested_logging_context,
run_in_background, run_in_background,
) )
from synapse.types import ISynapseReactor
from synapse.util import Clock from synapse.util import Clock
from .. import unittest from .. import unittest
reactor = cast(ISynapseReactor, _reactor)
class LoggingContextTestCase(unittest.TestCase): class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value): def _check_test_key(self, value: str) -> None:
self.assertEqual(current_context().name, value) context = current_context()
assert isinstance(context, LoggingContext)
self.assertEqual(context.name, value)
def test_with_context(self): def test_with_context(self) -> None:
with LoggingContext("test"): with LoggingContext("test"):
self._check_test_key("test") self._check_test_key("test")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_sleep(self): def test_sleep(self) -> Generator["defer.Deferred[object]", object, None]:
clock = Clock(reactor) clock = Clock(reactor)
@defer.inlineCallbacks @defer.inlineCallbacks
def competing_callback(): def competing_callback() -> Generator["defer.Deferred[object]", object, None]:
with LoggingContext("competing"): with LoggingContext("competing"):
yield clock.sleep(0) yield clock.sleep(0)
self._check_test_key("competing") self._check_test_key("competing")
@ -39,17 +60,18 @@ class LoggingContextTestCase(unittest.TestCase):
yield clock.sleep(0) yield clock.sleep(0)
self._check_test_key("one") self._check_test_key("one")
def _test_run_in_background(self, function): def _test_run_in_background(self, function: Callable[[], object]) -> defer.Deferred:
sentinel_context = current_context() sentinel_context = current_context()
callback_completed = [False] callback_completed = False
with LoggingContext("one"): with LoggingContext("one"):
# fire off function, but don't wait on it. # fire off function, but don't wait on it.
d2 = run_in_background(function) d2 = run_in_background(function)
def cb(res): def cb(res: object) -> object:
callback_completed[0] = True nonlocal callback_completed
callback_completed = True
return res return res
d2.addCallback(cb) d2.addCallback(cb)
@ -60,8 +82,8 @@ class LoggingContextTestCase(unittest.TestCase):
# the logcontext is left in a sane state. # the logcontext is left in a sane state.
d2 = defer.Deferred() d2 = defer.Deferred()
def check_logcontext(): def check_logcontext() -> None:
if not callback_completed[0]: if not callback_completed:
reactor.callLater(0.01, check_logcontext) reactor.callLater(0.01, check_logcontext)
return return
@ -78,31 +100,31 @@ class LoggingContextTestCase(unittest.TestCase):
# test is done once d2 finishes # test is done once d2 finishes
return d2 return d2
def test_run_in_background_with_blocking_fn(self): def test_run_in_background_with_blocking_fn(self) -> defer.Deferred:
@defer.inlineCallbacks @defer.inlineCallbacks
def blocking_function(): def blocking_function() -> Generator["defer.Deferred[object]", object, None]:
yield Clock(reactor).sleep(0) yield Clock(reactor).sleep(0)
return self._test_run_in_background(blocking_function) return self._test_run_in_background(blocking_function)
def test_run_in_background_with_non_blocking_fn(self): def test_run_in_background_with_non_blocking_fn(self) -> defer.Deferred:
@defer.inlineCallbacks @defer.inlineCallbacks
def nonblocking_function(): def nonblocking_function() -> Generator["defer.Deferred[object]", object, None]:
with PreserveLoggingContext(): with PreserveLoggingContext():
yield defer.succeed(None) yield defer.succeed(None)
return self._test_run_in_background(nonblocking_function) return self._test_run_in_background(nonblocking_function)
def test_run_in_background_with_chained_deferred(self): def test_run_in_background_with_chained_deferred(self) -> defer.Deferred:
# a function which returns a deferred which looks like it has been # a function which returns a deferred which looks like it has been
# called, but is actually paused # called, but is actually paused
def testfunc(): def testfunc() -> defer.Deferred:
return make_deferred_yieldable(_chained_deferred_function()) return make_deferred_yieldable(_chained_deferred_function())
return self._test_run_in_background(testfunc) return self._test_run_in_background(testfunc)
def test_run_in_background_with_coroutine(self): def test_run_in_background_with_coroutine(self) -> defer.Deferred:
async def testfunc(): async def testfunc() -> None:
self._check_test_key("one") self._check_test_key("one")
d = Clock(reactor).sleep(0) d = Clock(reactor).sleep(0)
self.assertIs(current_context(), SENTINEL_CONTEXT) self.assertIs(current_context(), SENTINEL_CONTEXT)
@ -111,18 +133,20 @@ class LoggingContextTestCase(unittest.TestCase):
return self._test_run_in_background(testfunc) return self._test_run_in_background(testfunc)
def test_run_in_background_with_nonblocking_coroutine(self): def test_run_in_background_with_nonblocking_coroutine(self) -> defer.Deferred:
async def testfunc(): async def testfunc() -> None:
self._check_test_key("one") self._check_test_key("one")
return self._test_run_in_background(testfunc) return self._test_run_in_background(testfunc)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_make_deferred_yieldable(self): def test_make_deferred_yieldable(
self,
) -> Generator["defer.Deferred[object]", object, None]:
# a function which returns an incomplete deferred, but doesn't follow # a function which returns an incomplete deferred, but doesn't follow
# the synapse rules. # the synapse rules.
def blocking_function(): def blocking_function() -> defer.Deferred:
d = defer.Deferred() d: defer.Deferred = defer.Deferred()
reactor.callLater(0, d.callback, None) reactor.callLater(0, d.callback, None)
return d return d
@ -139,7 +163,9 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("one") self._check_test_key("one")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_make_deferred_yieldable_with_chained_deferreds(self): def test_make_deferred_yieldable_with_chained_deferreds(
self,
) -> Generator["defer.Deferred[object]", object, None]:
sentinel_context = current_context() sentinel_context = current_context()
with LoggingContext("one"): with LoggingContext("one"):
@ -152,7 +178,7 @@ class LoggingContextTestCase(unittest.TestCase):
# now it should be restored # now it should be restored
self._check_test_key("one") self._check_test_key("one")
def test_nested_logging_context(self): def test_nested_logging_context(self) -> None:
with LoggingContext("foo"): with LoggingContext("foo"):
nested_context = nested_logging_context(suffix="bar") nested_context = nested_logging_context(suffix="bar")
self.assertEqual(nested_context.name, "foo-bar") self.assertEqual(nested_context.name, "foo-bar")
@ -161,11 +187,11 @@ class LoggingContextTestCase(unittest.TestCase):
# a function which returns a deferred which has been "called", but # a function which returns a deferred which has been "called", but
# which had a function which returned another incomplete deferred on # which had a function which returned another incomplete deferred on
# its callback list, so won't yet call any other new callbacks. # its callback list, so won't yet call any other new callbacks.
def _chained_deferred_function(): def _chained_deferred_function() -> defer.Deferred:
d = defer.succeed(None) d = defer.succeed(None)
def cb(res): def cb(res: object) -> defer.Deferred:
d2 = defer.Deferred() d2: defer.Deferred = defer.Deferred()
reactor.callLater(0, d2.callback, res) reactor.callLater(0, d2.callback, res)
return d2 return d2

View file

@ -23,7 +23,7 @@ class TestException(Exception):
class LogFormatterTestCase(unittest.TestCase): class LogFormatterTestCase(unittest.TestCase):
def test_formatter(self): def test_formatter(self) -> None:
formatter = LogFormatter() formatter = LogFormatter()
try: try:

View file

@ -13,10 +13,11 @@
# limitations under the License. # limitations under the License.
from typing import List from typing import List, Tuple
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from synapse.metrics.jemalloc import JemallocStats from synapse.metrics.jemalloc import JemallocStats
from synapse.types import JsonDict
from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries
from synapse.util.caches.treecache import TreeCache from synapse.util.caches.treecache import TreeCache
@ -25,14 +26,14 @@ from tests.unittest import override_config
class LruCacheTestCase(unittest.HomeserverTestCase): class LruCacheTestCase(unittest.HomeserverTestCase):
def test_get_set(self): def test_get_set(self) -> None:
cache = LruCache(1) cache: LruCache[str, str] = LruCache(1)
cache["key"] = "value" cache["key"] = "value"
self.assertEqual(cache.get("key"), "value") self.assertEqual(cache.get("key"), "value")
self.assertEqual(cache["key"], "value") self.assertEqual(cache["key"], "value")
def test_eviction(self): def test_eviction(self) -> None:
cache = LruCache(2) cache: LruCache[int, int] = LruCache(2)
cache[1] = 1 cache[1] = 1
cache[2] = 2 cache[2] = 2
@ -45,8 +46,8 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get(2), 2) self.assertEqual(cache.get(2), 2)
self.assertEqual(cache.get(3), 3) self.assertEqual(cache.get(3), 3)
def test_setdefault(self): def test_setdefault(self) -> None:
cache = LruCache(1) cache: LruCache[str, int] = LruCache(1)
self.assertEqual(cache.setdefault("key", 1), 1) self.assertEqual(cache.setdefault("key", 1), 1)
self.assertEqual(cache.get("key"), 1) self.assertEqual(cache.get("key"), 1)
self.assertEqual(cache.setdefault("key", 2), 1) self.assertEqual(cache.setdefault("key", 2), 1)
@ -54,14 +55,15 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
cache["key"] = 2 # Make sure overriding works. cache["key"] = 2 # Make sure overriding works.
self.assertEqual(cache.get("key"), 2) self.assertEqual(cache.get("key"), 2)
def test_pop(self): def test_pop(self) -> None:
cache = LruCache(1) cache: LruCache[str, int] = LruCache(1)
cache["key"] = 1 cache["key"] = 1
self.assertEqual(cache.pop("key"), 1) self.assertEqual(cache.pop("key"), 1)
self.assertEqual(cache.pop("key"), None) self.assertEqual(cache.pop("key"), None)
def test_del_multi(self): def test_del_multi(self) -> None:
cache = LruCache(4, cache_type=TreeCache) # The type here isn't quite correct as they don't handle TreeCache well.
cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache)
cache[("animal", "cat")] = "mew" cache[("animal", "cat")] = "mew"
cache[("animal", "dog")] = "woof" cache[("animal", "dog")] = "woof"
cache[("vehicles", "car")] = "vroom" cache[("vehicles", "car")] = "vroom"
@ -71,7 +73,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get(("animal", "cat")), "mew") self.assertEqual(cache.get(("animal", "cat")), "mew")
self.assertEqual(cache.get(("vehicles", "car")), "vroom") self.assertEqual(cache.get(("vehicles", "car")), "vroom")
cache.del_multi(("animal",)) cache.del_multi(("animal",)) # type: ignore[arg-type]
self.assertEqual(len(cache), 2) self.assertEqual(len(cache), 2)
self.assertEqual(cache.get(("animal", "cat")), None) self.assertEqual(cache.get(("animal", "cat")), None)
self.assertEqual(cache.get(("animal", "dog")), None) self.assertEqual(cache.get(("animal", "dog")), None)
@ -79,22 +81,22 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get(("vehicles", "train")), "chuff") self.assertEqual(cache.get(("vehicles", "train")), "chuff")
# Man from del_multi say "Yes". # Man from del_multi say "Yes".
def test_clear(self): def test_clear(self) -> None:
cache = LruCache(1) cache: LruCache[str, int] = LruCache(1)
cache["key"] = 1 cache["key"] = 1
cache.clear() cache.clear()
self.assertEqual(len(cache), 0) self.assertEqual(len(cache), 0)
@override_config({"caches": {"per_cache_factors": {"mycache": 10}}}) @override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
def test_special_size(self): def test_special_size(self) -> None:
cache = LruCache(10, "mycache") cache: LruCache = LruCache(10, "mycache")
self.assertEqual(cache.max_size, 100) self.assertEqual(cache.max_size, 100)
class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_get(self): def test_get(self) -> None:
m = Mock() m = Mock()
cache = LruCache(1) cache: LruCache[str, str] = LruCache(1)
cache.set("key", "value") cache.set("key", "value")
self.assertFalse(m.called) self.assertFalse(m.called)
@ -111,9 +113,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.set("key", "value") cache.set("key", "value")
self.assertEqual(m.call_count, 1) self.assertEqual(m.call_count, 1)
def test_multi_get(self): def test_multi_get(self) -> None:
m = Mock() m = Mock()
cache = LruCache(1) cache: LruCache[str, str] = LruCache(1)
cache.set("key", "value") cache.set("key", "value")
self.assertFalse(m.called) self.assertFalse(m.called)
@ -130,9 +132,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.set("key", "value") cache.set("key", "value")
self.assertEqual(m.call_count, 1) self.assertEqual(m.call_count, 1)
def test_set(self): def test_set(self) -> None:
m = Mock() m = Mock()
cache = LruCache(1) cache: LruCache[str, str] = LruCache(1)
cache.set("key", "value", callbacks=[m]) cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called) self.assertFalse(m.called)
@ -146,9 +148,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.set("key", "value") cache.set("key", "value")
self.assertEqual(m.call_count, 1) self.assertEqual(m.call_count, 1)
def test_pop(self): def test_pop(self) -> None:
m = Mock() m = Mock()
cache = LruCache(1) cache: LruCache[str, str] = LruCache(1)
cache.set("key", "value", callbacks=[m]) cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called) self.assertFalse(m.called)
@ -162,12 +164,13 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.pop("key") cache.pop("key")
self.assertEqual(m.call_count, 1) self.assertEqual(m.call_count, 1)
def test_del_multi(self): def test_del_multi(self) -> None:
m1 = Mock() m1 = Mock()
m2 = Mock() m2 = Mock()
m3 = Mock() m3 = Mock()
m4 = Mock() m4 = Mock()
cache = LruCache(4, cache_type=TreeCache) # The type here isn't quite correct as they don't handle TreeCache well.
cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache)
cache.set(("a", "1"), "value", callbacks=[m1]) cache.set(("a", "1"), "value", callbacks=[m1])
cache.set(("a", "2"), "value", callbacks=[m2]) cache.set(("a", "2"), "value", callbacks=[m2])
@ -179,17 +182,17 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
self.assertEqual(m3.call_count, 0) self.assertEqual(m3.call_count, 0)
self.assertEqual(m4.call_count, 0) self.assertEqual(m4.call_count, 0)
cache.del_multi(("a",)) cache.del_multi(("a",)) # type: ignore[arg-type]
self.assertEqual(m1.call_count, 1) self.assertEqual(m1.call_count, 1)
self.assertEqual(m2.call_count, 1) self.assertEqual(m2.call_count, 1)
self.assertEqual(m3.call_count, 0) self.assertEqual(m3.call_count, 0)
self.assertEqual(m4.call_count, 0) self.assertEqual(m4.call_count, 0)
def test_clear(self): def test_clear(self) -> None:
m1 = Mock() m1 = Mock()
m2 = Mock() m2 = Mock()
cache = LruCache(5) cache: LruCache[str, str] = LruCache(5)
cache.set("key1", "value", callbacks=[m1]) cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2]) cache.set("key2", "value", callbacks=[m2])
@ -202,11 +205,11 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
self.assertEqual(m1.call_count, 1) self.assertEqual(m1.call_count, 1)
self.assertEqual(m2.call_count, 1) self.assertEqual(m2.call_count, 1)
def test_eviction(self): def test_eviction(self) -> None:
m1 = Mock(name="m1") m1 = Mock(name="m1")
m2 = Mock(name="m2") m2 = Mock(name="m2")
m3 = Mock(name="m3") m3 = Mock(name="m3")
cache = LruCache(2) cache: LruCache[str, str] = LruCache(2)
cache.set("key1", "value", callbacks=[m1]) cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2]) cache.set("key2", "value", callbacks=[m2])
@ -241,8 +244,8 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
class LruCacheSizedTestCase(unittest.HomeserverTestCase): class LruCacheSizedTestCase(unittest.HomeserverTestCase):
def test_evict(self): def test_evict(self) -> None:
cache = LruCache(5, size_callback=len) cache: LruCache[str, List[int]] = LruCache(5, size_callback=len)
cache["key1"] = [0] cache["key1"] = [0]
cache["key2"] = [1, 2] cache["key2"] = [1, 2]
cache["key3"] = [3] cache["key3"] = [3]
@ -269,6 +272,7 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
cache["key1"] = [] cache["key1"] = []
self.assertEqual(len(cache), 0) self.assertEqual(len(cache), 0)
assert isinstance(cache.cache, dict)
cache.cache["key1"].drop_from_cache() cache.cache["key1"].drop_from_cache()
self.assertIsNone( self.assertIsNone(
cache.pop("key1"), "Cache entry should have been evicted but wasn't" cache.pop("key1"), "Cache entry should have been evicted but wasn't"
@ -278,17 +282,17 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
class TimeEvictionTestCase(unittest.HomeserverTestCase): class TimeEvictionTestCase(unittest.HomeserverTestCase):
"""Test that time based eviction works correctly.""" """Test that time based eviction works correctly."""
def default_config(self): def default_config(self) -> JsonDict:
config = super().default_config() config = super().default_config()
config.setdefault("caches", {})["expiry_time"] = "30m" config.setdefault("caches", {})["expiry_time"] = "30m"
return config return config
def test_evict(self): def test_evict(self) -> None:
setup_expire_lru_cache_entries(self.hs) setup_expire_lru_cache_entries(self.hs)
cache = LruCache(5, clock=self.hs.get_clock()) cache: LruCache[str, int] = LruCache(5, clock=self.hs.get_clock())
# Check that we evict entries we haven't accessed for 30 minutes. # Check that we evict entries we haven't accessed for 30 minutes.
cache["key1"] = 1 cache["key1"] = 1
@ -332,7 +336,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
} }
) )
@patch("synapse.util.caches.lrucache.get_jemalloc_stats") @patch("synapse.util.caches.lrucache.get_jemalloc_stats")
def test_evict_memory(self, jemalloc_interface) -> None: def test_evict_memory(self, jemalloc_interface: Mock) -> None:
mock_jemalloc_class = Mock(spec=JemallocStats) mock_jemalloc_class = Mock(spec=JemallocStats)
jemalloc_interface.return_value = mock_jemalloc_class jemalloc_interface.return_value = mock_jemalloc_class
@ -340,7 +344,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
mock_jemalloc_class.get_stat.return_value = 924288000 mock_jemalloc_class.get_stat.return_value = 924288000
setup_expire_lru_cache_entries(self.hs) setup_expire_lru_cache_entries(self.hs)
cache = LruCache(4, clock=self.hs.get_clock()) cache: LruCache[str, int] = LruCache(4, clock=self.hs.get_clock())
cache["key1"] = 1 cache["key1"] = 1
cache["key2"] = 2 cache["key2"] = 2

View file

@ -21,14 +21,14 @@ from tests.unittest import TestCase
class MacaroonGeneratorTestCase(TestCase): class MacaroonGeneratorTestCase(TestCase):
def setUp(self): def setUp(self) -> None:
self.reactor, hs_clock = get_clock() self.reactor, hs_clock = get_clock()
self.macaroon_generator = MacaroonGenerator(hs_clock, "tesths", b"verysecret") self.macaroon_generator = MacaroonGenerator(hs_clock, "tesths", b"verysecret")
self.other_macaroon_generator = MacaroonGenerator( self.other_macaroon_generator = MacaroonGenerator(
hs_clock, "tesths", b"anothersecretkey" hs_clock, "tesths", b"anothersecretkey"
) )
def test_guest_access_token(self): def test_guest_access_token(self) -> None:
"""Test the generation and verification of guest access tokens""" """Test the generation and verification of guest access tokens"""
token = self.macaroon_generator.generate_guest_access_token("@user:tesths") token = self.macaroon_generator.generate_guest_access_token("@user:tesths")
user_id = self.macaroon_generator.verify_guest_token(token) user_id = self.macaroon_generator.verify_guest_token(token)
@ -47,7 +47,7 @@ class MacaroonGeneratorTestCase(TestCase):
with self.assertRaises(MacaroonVerificationFailedException): with self.assertRaises(MacaroonVerificationFailedException):
self.macaroon_generator.verify_guest_token(token) self.macaroon_generator.verify_guest_token(token)
def test_delete_pusher_token(self): def test_delete_pusher_token(self) -> None:
"""Test the generation and verification of delete_pusher tokens""" """Test the generation and verification of delete_pusher tokens"""
token = self.macaroon_generator.generate_delete_pusher_token( token = self.macaroon_generator.generate_delete_pusher_token(
"@user:tesths", "m.mail", "john@example.com" "@user:tesths", "m.mail", "john@example.com"
@ -84,7 +84,7 @@ class MacaroonGeneratorTestCase(TestCase):
) )
self.assertEqual(user_id, "@user:tesths") self.assertEqual(user_id, "@user:tesths")
def test_oidc_session_token(self): def test_oidc_session_token(self) -> None:
"""Test the generation and verification of OIDC session cookies""" """Test the generation and verification of OIDC session cookies"""
state = "arandomstate" state = "arandomstate"
session_data = OidcSessionData( session_data = OidcSessionData(

View file

@ -13,16 +13,19 @@
# limitations under the License. # limitations under the License.
from typing import Optional from typing import Optional
from twisted.internet.defer import Deferred
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRatelimitSettings
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import get_clock from tests.server import ThreadedMemoryReactorClock, get_clock
from tests.unittest import TestCase from tests.unittest import TestCase
from tests.utils import default_config from tests.utils import default_config
class FederationRateLimiterTestCase(TestCase): class FederationRateLimiterTestCase(TestCase):
def test_ratelimit(self): def test_ratelimit(self) -> None:
"""A simple test with the default values""" """A simple test with the default values"""
reactor, clock = get_clock() reactor, clock = get_clock()
rc_config = build_rc_config() rc_config = build_rc_config()
@ -32,7 +35,7 @@ class FederationRateLimiterTestCase(TestCase):
# shouldn't block # shouldn't block
self.successResultOf(d1) self.successResultOf(d1)
def test_concurrent_limit(self): def test_concurrent_limit(self) -> None:
"""Test what happens when we hit the concurrent limit""" """Test what happens when we hit the concurrent limit"""
reactor, clock = get_clock() reactor, clock = get_clock()
rc_config = build_rc_config({"rc_federation": {"concurrent": 2}}) rc_config = build_rc_config({"rc_federation": {"concurrent": 2}})
@ -56,7 +59,7 @@ class FederationRateLimiterTestCase(TestCase):
cm2.__exit__(None, None, None) cm2.__exit__(None, None, None)
self.successResultOf(d3) self.successResultOf(d3)
def test_sleep_limit(self): def test_sleep_limit(self) -> None:
"""Test what happens when we hit the sleep limit""" """Test what happens when we hit the sleep limit"""
reactor, clock = get_clock() reactor, clock = get_clock()
rc_config = build_rc_config( rc_config = build_rc_config(
@ -79,7 +82,7 @@ class FederationRateLimiterTestCase(TestCase):
self.assertAlmostEqual(sleep_time, 500, places=3) self.assertAlmostEqual(sleep_time, 500, places=3)
def _await_resolution(reactor, d): def _await_resolution(reactor: ThreadedMemoryReactorClock, d: Deferred) -> float:
"""advance the clock until the deferred completes. """advance the clock until the deferred completes.
Returns the number of milliseconds it took to complete. Returns the number of milliseconds it took to complete.
@ -90,7 +93,7 @@ def _await_resolution(reactor, d):
return (reactor.seconds() - start_time) * 1000 return (reactor.seconds() - start_time) * 1000
def build_rc_config(settings: Optional[dict] = None): def build_rc_config(settings: Optional[dict] = None) -> FederationRatelimitSettings:
config_dict = default_config("test") config_dict = default_config("test")
config_dict.update(settings or {}) config_dict.update(settings or {})
config = HomeServerConfig() config = HomeServerConfig()

View file

@ -22,7 +22,7 @@ from tests.unittest import HomeserverTestCase
class RetryLimiterTestCase(HomeserverTestCase): class RetryLimiterTestCase(HomeserverTestCase):
def test_new_destination(self): def test_new_destination(self) -> None:
"""A happy-path case with a new destination and a successful operation""" """A happy-path case with a new destination and a successful operation"""
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
@ -36,7 +36,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings) self.assertIsNone(new_timings)
def test_limiter(self): def test_limiter(self) -> None:
"""General test case which walks through the process of a failing request""" """General test case which walks through the process of a failing request"""
store = self.hs.get_datastores().main store = self.hs.get_datastores().main

View file

@ -49,7 +49,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
acquired_d: "Deferred[None]" = Deferred() acquired_d: "Deferred[None]" = Deferred()
unblock_d: "Deferred[None]" = Deferred() unblock_d: "Deferred[None]" = Deferred()
async def reader_or_writer(): async def reader_or_writer() -> str:
async with read_or_write(key): async with read_or_write(key):
acquired_d.callback(None) acquired_d.callback(None)
await unblock_d await unblock_d
@ -134,7 +134,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
d.called, msg="deferred %d was unexpectedly resolved" % (i + n) d.called, msg="deferred %d was unexpectedly resolved" % (i + n)
) )
def test_rwlock(self): def test_rwlock(self) -> None:
rwlock = ReadWriteLock() rwlock = ReadWriteLock()
key = "key" key = "key"
@ -197,7 +197,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
_, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader") _, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader")
self.assertTrue(acquired_d.called) self.assertTrue(acquired_d.called)
def test_lock_handoff_to_nonblocking_writer(self): def test_lock_handoff_to_nonblocking_writer(self) -> None:
"""Test a writer handing the lock to another writer that completes instantly.""" """Test a writer handing the lock to another writer that completes instantly."""
rwlock = ReadWriteLock() rwlock = ReadWriteLock()
key = "key" key = "key"
@ -216,7 +216,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed") d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed")
self.assertTrue(d3.called) self.assertTrue(d3.called)
def test_cancellation_while_holding_read_lock(self): def test_cancellation_while_holding_read_lock(self) -> None:
"""Test cancellation while holding a read lock. """Test cancellation while holding a read lock.
A waiting writer should be given the lock when the reader holding the lock is A waiting writer should be given the lock when the reader holding the lock is
@ -242,7 +242,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
) )
self.assertEqual("write completed", self.successResultOf(writer_d)) self.assertEqual("write completed", self.successResultOf(writer_d))
def test_cancellation_while_holding_write_lock(self): def test_cancellation_while_holding_write_lock(self) -> None:
"""Test cancellation while holding a write lock. """Test cancellation while holding a write lock.
A waiting reader should be given the lock when the writer holding the lock is A waiting reader should be given the lock when the writer holding the lock is
@ -268,7 +268,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
) )
self.assertEqual("read completed", self.successResultOf(reader_d)) self.assertEqual("read completed", self.successResultOf(reader_d))
def test_cancellation_while_waiting_for_read_lock(self): def test_cancellation_while_waiting_for_read_lock(self) -> None:
"""Test cancellation while waiting for a read lock. """Test cancellation while waiting for a read lock.
Tests that cancelling a waiting reader: Tests that cancelling a waiting reader:
@ -319,7 +319,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
) )
self.assertEqual("write 2 completed", self.successResultOf(writer2_d)) self.assertEqual("write 2 completed", self.successResultOf(writer2_d))
def test_cancellation_while_waiting_for_write_lock(self): def test_cancellation_while_waiting_for_write_lock(self) -> None:
"""Test cancellation while waiting for a write lock. """Test cancellation while waiting for a write lock.
Tests that cancelling a waiting writer: Tests that cancelling a waiting writer:

View file

@ -8,7 +8,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
Tests for StreamChangeCache. Tests for StreamChangeCache.
""" """
def test_prefilled_cache(self): def test_prefilled_cache(self) -> None:
""" """
Providing a prefilled cache to StreamChangeCache will result in a cache Providing a prefilled cache to StreamChangeCache will result in a cache
with the prefilled-cache entered in. with the prefilled-cache entered in.
@ -16,7 +16,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2}) cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2})
self.assertTrue(cache.has_entity_changed("user@foo.com", 1)) self.assertTrue(cache.has_entity_changed("user@foo.com", 1))
def test_has_entity_changed(self): def test_has_entity_changed(self) -> None:
""" """
StreamChangeCache.entity_has_changed will mark entities as changed, and StreamChangeCache.entity_has_changed will mark entities as changed, and
has_entity_changed will observe the changed entities. has_entity_changed will observe the changed entities.
@ -52,7 +52,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertTrue(cache.has_entity_changed("user@foo.com", 0)) self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
self.assertTrue(cache.has_entity_changed("not@here.website", 0)) self.assertTrue(cache.has_entity_changed("not@here.website", 0))
def test_entity_has_changed_pops_off_start(self): def test_entity_has_changed_pops_off_start(self) -> None:
""" """
StreamChangeCache.entity_has_changed will respect the max size and StreamChangeCache.entity_has_changed will respect the max size and
purge the oldest items upon reaching that max size. purge the oldest items upon reaching that max size.
@ -86,7 +86,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
) )
self.assertIsNone(cache.get_all_entities_changed(1)) self.assertIsNone(cache.get_all_entities_changed(1))
def test_get_all_entities_changed(self): def test_get_all_entities_changed(self) -> None:
""" """
StreamChangeCache.get_all_entities_changed will return all changed StreamChangeCache.get_all_entities_changed will return all changed
entities since the given position. If the position is before the start entities since the given position. If the position is before the start
@ -142,7 +142,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
r = cache.get_all_entities_changed(3) r = cache.get_all_entities_changed(3)
self.assertTrue(r == ok1 or r == ok2) self.assertTrue(r == ok1 or r == ok2)
def test_has_any_entity_changed(self): def test_has_any_entity_changed(self) -> None:
""" """
StreamChangeCache.has_any_entity_changed will return True if any StreamChangeCache.has_any_entity_changed will return True if any
entities have been changed since the provided stream position, and entities have been changed since the provided stream position, and
@ -168,7 +168,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertFalse(cache.has_any_entity_changed(2)) self.assertFalse(cache.has_any_entity_changed(2))
self.assertFalse(cache.has_any_entity_changed(3)) self.assertFalse(cache.has_any_entity_changed(3))
def test_get_entities_changed(self): def test_get_entities_changed(self) -> None:
""" """
StreamChangeCache.get_entities_changed will return the entities in the StreamChangeCache.get_entities_changed will return the entities in the
given list that have changed since the provided stream ID. If the given list that have changed since the provided stream ID. If the
@ -228,7 +228,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
{"bar@baz.net"}, {"bar@baz.net"},
) )
def test_max_pos(self): def test_max_pos(self) -> None:
""" """
StreamChangeCache.get_max_pos_of_last_change will return the most StreamChangeCache.get_max_pos_of_last_change will return the most
recent point where the entity could have changed. If the entity is not recent point where the entity could have changed. If the entity is not

View file

@ -19,7 +19,7 @@ from .. import unittest
class StringUtilsTestCase(unittest.TestCase): class StringUtilsTestCase(unittest.TestCase):
def test_client_secret_regex(self): def test_client_secret_regex(self) -> None:
"""Ensure that client_secret does not contain illegal characters""" """Ensure that client_secret does not contain illegal characters"""
good = [ good = [
"abcde12345", "abcde12345",
@ -46,7 +46,7 @@ class StringUtilsTestCase(unittest.TestCase):
with self.assertRaises(SynapseError): with self.assertRaises(SynapseError):
assert_valid_client_secret(client_secret) assert_valid_client_secret(client_secret)
def test_base62_encode(self): def test_base62_encode(self) -> None:
self.assertEqual("0", base62_encode(0)) self.assertEqual("0", base62_encode(0))
self.assertEqual("10", base62_encode(62)) self.assertEqual("10", base62_encode(62))
self.assertEqual("1c", base62_encode(100)) self.assertEqual("1c", base62_encode(100))

View file

@ -18,31 +18,31 @@ from tests.unittest import HomeserverTestCase
class CanonicaliseEmailTests(HomeserverTestCase): class CanonicaliseEmailTests(HomeserverTestCase):
def test_no_at(self): def test_no_at(self) -> None:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
canonicalise_email("address-without-at.bar") canonicalise_email("address-without-at.bar")
def test_two_at(self): def test_two_at(self) -> None:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
canonicalise_email("foo@foo@test.bar") canonicalise_email("foo@foo@test.bar")
def test_bad_format(self): def test_bad_format(self) -> None:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
canonicalise_email("user@bad.example.net@good.example.com") canonicalise_email("user@bad.example.net@good.example.com")
def test_valid_format(self): def test_valid_format(self) -> None:
self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar") self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar")
def test_domain_to_lower(self): def test_domain_to_lower(self) -> None:
self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar") self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar")
def test_domain_with_umlaut(self): def test_domain_with_umlaut(self) -> None:
self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com") self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com")
def test_address_casefold(self): def test_address_casefold(self) -> None:
self.assertEqual( self.assertEqual(
canonicalise_email("Strauß@Example.com"), "strauss@example.com" canonicalise_email("Strauß@Example.com"), "strauss@example.com"
) )
def test_address_trim(self): def test_address_trim(self) -> None:
self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar") self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar")

View file

@ -19,7 +19,7 @@ from .. import unittest
class TreeCacheTestCase(unittest.TestCase): class TreeCacheTestCase(unittest.TestCase):
def test_get_set_onelevel(self): def test_get_set_onelevel(self) -> None:
cache = TreeCache() cache = TreeCache()
cache[("a",)] = "A" cache[("a",)] = "A"
cache[("b",)] = "B" cache[("b",)] = "B"
@ -27,7 +27,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.get(("b",)), "B") self.assertEqual(cache.get(("b",)), "B")
self.assertEqual(len(cache), 2) self.assertEqual(len(cache), 2)
def test_pop_onelevel(self): def test_pop_onelevel(self) -> None:
cache = TreeCache() cache = TreeCache()
cache[("a",)] = "A" cache[("a",)] = "A"
cache[("b",)] = "B" cache[("b",)] = "B"
@ -36,7 +36,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.get(("b",)), "B") self.assertEqual(cache.get(("b",)), "B")
self.assertEqual(len(cache), 1) self.assertEqual(len(cache), 1)
def test_get_set_twolevel(self): def test_get_set_twolevel(self) -> None:
cache = TreeCache() cache = TreeCache()
cache[("a", "a")] = "AA" cache[("a", "a")] = "AA"
cache[("a", "b")] = "AB" cache[("a", "b")] = "AB"
@ -46,7 +46,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.get(("b", "a")), "BA") self.assertEqual(cache.get(("b", "a")), "BA")
self.assertEqual(len(cache), 3) self.assertEqual(len(cache), 3)
def test_pop_twolevel(self): def test_pop_twolevel(self) -> None:
cache = TreeCache() cache = TreeCache()
cache[("a", "a")] = "AA" cache[("a", "a")] = "AA"
cache[("a", "b")] = "AB" cache[("a", "b")] = "AB"
@ -58,7 +58,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.pop(("b", "a")), None) self.assertEqual(cache.pop(("b", "a")), None)
self.assertEqual(len(cache), 1) self.assertEqual(len(cache), 1)
def test_pop_mixedlevel(self): def test_pop_mixedlevel(self) -> None:
cache = TreeCache() cache = TreeCache()
cache[("a", "a")] = "AA" cache[("a", "a")] = "AA"
cache[("a", "b")] = "AB" cache[("a", "b")] = "AB"
@ -72,14 +72,14 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual({"AA", "AB"}, set(iterate_tree_cache_entry(popped))) self.assertEqual({"AA", "AB"}, set(iterate_tree_cache_entry(popped)))
def test_clear(self): def test_clear(self) -> None:
cache = TreeCache() cache = TreeCache()
cache[("a",)] = "A" cache[("a",)] = "A"
cache[("b",)] = "B" cache[("b",)] = "B"
cache.clear() cache.clear()
self.assertEqual(len(cache), 0) self.assertEqual(len(cache), 0)
def test_contains(self): def test_contains(self) -> None:
cache = TreeCache() cache = TreeCache()
cache[("a",)] = "A" cache[("a",)] = "A"
self.assertTrue(("a",) in cache) self.assertTrue(("a",) in cache)

View file

@ -18,8 +18,8 @@ from .. import unittest
class WheelTimerTestCase(unittest.TestCase): class WheelTimerTestCase(unittest.TestCase):
def test_single_insert_fetch(self): def test_single_insert_fetch(self) -> None:
wheel = WheelTimer(bucket_size=5) wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj = object() obj = object()
wheel.insert(100, obj, 150) wheel.insert(100, obj, 150)
@ -32,8 +32,8 @@ class WheelTimerTestCase(unittest.TestCase):
self.assertListEqual(wheel.fetch(156), [obj]) self.assertListEqual(wheel.fetch(156), [obj])
self.assertListEqual(wheel.fetch(170), []) self.assertListEqual(wheel.fetch(170), [])
def test_multi_insert(self): def test_multi_insert(self) -> None:
wheel = WheelTimer(bucket_size=5) wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj1 = object() obj1 = object()
obj2 = object() obj2 = object()
@ -50,15 +50,15 @@ class WheelTimerTestCase(unittest.TestCase):
self.assertListEqual(wheel.fetch(200), [obj3]) self.assertListEqual(wheel.fetch(200), [obj3])
self.assertListEqual(wheel.fetch(210), []) self.assertListEqual(wheel.fetch(210), [])
def test_insert_past(self): def test_insert_past(self) -> None:
wheel = WheelTimer(bucket_size=5) wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj = object() obj = object()
wheel.insert(100, obj, 50) wheel.insert(100, obj, 50)
self.assertListEqual(wheel.fetch(120), [obj]) self.assertListEqual(wheel.fetch(120), [obj])
def test_insert_past_multi(self): def test_insert_past_multi(self) -> None:
wheel = WheelTimer(bucket_size=5) wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj1 = object() obj1 = object()
obj2 = object() obj2 = object()