Add type hints to tests files. (#12256)

This commit is contained in:
Dirk Klimpel 2022-03-21 14:43:16 +01:00 committed by GitHub
parent 0a59f977a2
commit 9d21ecf7ce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 101 additions and 88 deletions

1
changelog.d/12256.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to tests files.

View file

@ -82,9 +82,7 @@ exclude = (?x)
|tests/server.py |tests/server.py
|tests/server_notices/test_resource_limits_server_notices.py |tests/server_notices/test_resource_limits_server_notices.py
|tests/state/test_v2.py |tests/state/test_v2.py
|tests/storage/test_background_update.py
|tests/storage/test_base.py |tests/storage/test_base.py
|tests/storage/test_id_generators.py
|tests/storage/test_roommember.py |tests/storage/test_roommember.py
|tests/test_metrics.py |tests/test_metrics.py
|tests/test_phone_home.py |tests/test_phone_home.py

View file

@ -18,11 +18,14 @@ from typing import Dict
from unittest.mock import ANY, Mock, call from unittest.mock import ANY, Mock, call
from twisted.internet import defer from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer from synapse.federation.transport.server import TransportLayerServer
from synapse.types import UserID, create_requester from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
@ -42,7 +45,9 @@ ROOM_ID = "a-room"
OTHER_ROOM_ID = "another-room" OTHER_ROOM_ID = "another-room"
def _expect_edu_transaction(edu_type, content, origin="test"): def _expect_edu_transaction(
edu_type: str, content: JsonDict, origin: str = "test"
) -> JsonDict:
return { return {
"origin": origin, "origin": origin,
"origin_server_ts": 1000000, "origin_server_ts": 1000000,
@ -51,12 +56,12 @@ def _expect_edu_transaction(edu_type, content, origin="test"):
} }
def _make_edu_transaction_json(edu_type, content): def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes:
return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8") return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8")
class TypingNotificationsTestCase(unittest.HomeserverTestCase): class TypingNotificationsTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# we mock out the keyring so as to skip the authentication check on the # we mock out the keyring so as to skip the authentication check on the
# federation API call. # federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"]) mock_keyring = Mock(spec=["verify_json_for_server"])
@ -83,7 +88,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
d["/_matrix/federation"] = TransportLayerServer(self.hs) d["/_matrix/federation"] = TransportLayerServer(self.hs)
return d return d
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
mock_notifier = hs.get_notifier() mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event self.on_new_event = mock_notifier.on_new_event
@ -111,24 +116,24 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = [] self.room_members = []
async def check_user_in_room(room_id, user_id): async def check_user_in_room(room_id: str, user_id: str) -> None:
if user_id not in [u.to_string() for u in self.room_members]: if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room") raise AuthError(401, "User is not in the room")
return None return None
hs.get_auth().check_user_in_room = check_user_in_room hs.get_auth().check_user_in_room = check_user_in_room
async def check_host_in_room(room_id, server_name): async def check_host_in_room(room_id: str, server_name: str) -> bool:
return room_id == ROOM_ID return room_id == ROOM_ID
hs.get_event_auth_handler().check_host_in_room = check_host_in_room hs.get_event_auth_handler().check_host_in_room = check_host_in_room
def get_joined_hosts_for_room(room_id): def get_joined_hosts_for_room(room_id: str):
return {member.domain for member in self.room_members} return {member.domain for member in self.room_members}
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
async def get_users_in_room(room_id): async def get_users_in_room(room_id: str):
return {str(u) for u in self.room_members} return {str(u) for u in self.room_members}
self.datastore.get_users_in_room = get_users_in_room self.datastore.get_users_in_room = get_users_in_room
@ -153,7 +158,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
lambda *args, **kwargs: make_awaitable(None) lambda *args, **kwargs: make_awaitable(None)
) )
def test_started_typing_local(self): def test_started_typing_local(self) -> None:
self.room_members = [U_APPLE, U_BANANA] self.room_members = [U_APPLE, U_BANANA]
self.assertEqual(self.event_source.get_current_key(), 0) self.assertEqual(self.event_source.get_current_key(), 0)
@ -187,7 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
) )
@override_config({"send_federation": True}) @override_config({"send_federation": True})
def test_started_typing_remote_send(self): def test_started_typing_remote_send(self) -> None:
self.room_members = [U_APPLE, U_ONION] self.room_members = [U_APPLE, U_ONION]
self.get_success( self.get_success(
@ -217,7 +222,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
try_trailing_slash_on_400=True, try_trailing_slash_on_400=True,
) )
def test_started_typing_remote_recv(self): def test_started_typing_remote_recv(self) -> None:
self.room_members = [U_APPLE, U_ONION] self.room_members = [U_APPLE, U_ONION]
self.assertEqual(self.event_source.get_current_key(), 0) self.assertEqual(self.event_source.get_current_key(), 0)
@ -256,7 +261,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_started_typing_remote_recv_not_in_room(self): def test_started_typing_remote_recv_not_in_room(self) -> None:
self.room_members = [U_APPLE, U_ONION] self.room_members = [U_APPLE, U_ONION]
self.assertEqual(self.event_source.get_current_key(), 0) self.assertEqual(self.event_source.get_current_key(), 0)
@ -292,7 +297,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[1], 0) self.assertEqual(events[1], 0)
@override_config({"send_federation": True}) @override_config({"send_federation": True})
def test_stopped_typing(self): def test_stopped_typing(self) -> None:
self.room_members = [U_APPLE, U_BANANA, U_ONION] self.room_members = [U_APPLE, U_BANANA, U_ONION]
# Gut-wrenching # Gut-wrenching
@ -343,7 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
) )
def test_typing_timeout(self): def test_typing_timeout(self) -> None:
self.room_members = [U_APPLE, U_BANANA] self.room_members = [U_APPLE, U_BANANA]
self.assertEqual(self.event_source.get_current_key(), 0) self.assertEqual(self.event_source.get_current_key(), 0)

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict from typing import Dict, Optional, Union
import frozendict import frozendict
@ -20,12 +20,13 @@ from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.push import push_rule_evaluator from synapse.push import push_rule_evaluator
from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.types import JsonDict
from tests import unittest from tests import unittest
class PushRuleEvaluatorTestCase(unittest.TestCase): class PushRuleEvaluatorTestCase(unittest.TestCase):
def _get_evaluator(self, content): def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent:
event = FrozenEvent( event = FrozenEvent(
{ {
"event_id": "$event_id", "event_id": "$event_id",
@ -39,12 +40,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
) )
room_member_count = 0 room_member_count = 0
sender_power_level = 0 sender_power_level = 0
power_levels = {} power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
return PushRuleEvaluatorForEvent( return PushRuleEvaluatorForEvent(
event, room_member_count, sender_power_level, power_levels event, room_member_count, sender_power_level, power_levels
) )
def test_display_name(self): def test_display_name(self) -> None:
"""Check for a matching display name in the body of the event.""" """Check for a matching display name in the body of the event."""
evaluator = self._get_evaluator({"body": "foo bar baz"}) evaluator = self._get_evaluator({"body": "foo bar baz"})
@ -71,20 +72,20 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar")) self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
def _assert_matches( def _assert_matches(
self, condition: Dict[str, Any], content: Dict[str, Any], msg=None self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
) -> None: ) -> None:
evaluator = self._get_evaluator(content) evaluator = self._get_evaluator(content)
self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg) self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)
def _assert_not_matches( def _assert_not_matches(
self, condition: Dict[str, Any], content: Dict[str, Any], msg=None self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
) -> None: ) -> None:
evaluator = self._get_evaluator(content) evaluator = self._get_evaluator(content)
self.assertFalse( self.assertFalse(
evaluator.matches(condition, "@user:test", "display_name"), msg evaluator.matches(condition, "@user:test", "display_name"), msg
) )
def test_event_match_body(self): def test_event_match_body(self) -> None:
"""Check that event_match conditions on content.body work as expected""" """Check that event_match conditions on content.body work as expected"""
# if the key is `content.body`, the pattern matches substrings. # if the key is `content.body`, the pattern matches substrings.
@ -165,7 +166,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
r"? after \ should match any character", r"? after \ should match any character",
) )
def test_event_match_non_body(self): def test_event_match_non_body(self) -> None:
"""Check that event_match conditions on other keys work as expected""" """Check that event_match conditions on other keys work as expected"""
# if the key is anything other than 'content.body', the pattern must match the # if the key is anything other than 'content.body', the pattern must match the
@ -241,7 +242,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"pattern should not match before a newline", "pattern should not match before a newline",
) )
def test_no_body(self): def test_no_body(self) -> None:
"""Not having a body shouldn't break the evaluator.""" """Not having a body shouldn't break the evaluator."""
evaluator = self._get_evaluator({}) evaluator = self._get_evaluator({})
@ -250,7 +251,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
} }
self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
def test_invalid_body(self): def test_invalid_body(self) -> None:
"""A non-string body should not break the evaluator.""" """A non-string body should not break the evaluator."""
condition = { condition = {
"kind": "contains_display_name", "kind": "contains_display_name",
@ -260,7 +261,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
evaluator = self._get_evaluator({"body": body}) evaluator = self._get_evaluator({"body": body})
self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
def test_tweaks_for_actions(self): def test_tweaks_for_actions(self) -> None:
""" """
This tests the behaviour of tweaks_for_actions. This tests the behaviour of tweaks_for_actions.
""" """

View file

@ -17,8 +17,12 @@ from unittest.mock import Mock
import yaml import yaml
from twisted.internet.defer import Deferred, ensureDeferred from twisted.internet.defer import Deferred, ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.background_updates import BackgroundUpdater
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable, simple_async_mock from tests.test_utils import make_awaitable, simple_async_mock
@ -26,7 +30,7 @@ from tests.unittest import override_config
class BackgroundUpdateTestCase(unittest.HomeserverTestCase): class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
# the base test class should have run the real bg updates for us # the base test class should have run the real bg updates for us
self.assertTrue( self.assertTrue(
@ -39,7 +43,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
) )
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
async def update(self, progress, count): async def update(self, progress: JsonDict, count: int) -> int:
duration_ms = 10 duration_ms = 10
await self.clock.sleep((count * duration_ms) / 1000) await self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1} progress = {"my_key": progress["my_key"] + 1}
@ -51,7 +55,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
) )
return count return count
def test_do_background_update(self): def test_do_background_update(self) -> None:
# the time we claim it takes to update one item when running the update # the time we claim it takes to update one item when running the update
duration_ms = 10 duration_ms = 10
@ -80,7 +84,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# second step: complete the update # second step: complete the update
# we should now get run with a much bigger number of items to update # we should now get run with a much bigger number of items to update
async def update(progress, count): async def update(progress: JsonDict, count: int) -> int:
self.assertEqual(progress, {"my_key": 2}) self.assertEqual(progress, {"my_key": 2})
self.assertAlmostEqual( self.assertAlmostEqual(
count, count,
@ -110,7 +114,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
""" """
) )
) )
def test_background_update_default_batch_set_by_config(self): def test_background_update_default_batch_set_by_config(self) -> None:
""" """
Test that the background update is run with the default_batch_size set by the config Test that the background update is run with the default_batch_size set by the config
""" """
@ -133,7 +137,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# on the first call, we should get run with the default background update size specified in the config # on the first call, we should get run with the default background update size specified in the config
self.update_handler.assert_called_once_with({"my_key": 1}, 20) self.update_handler.assert_called_once_with({"my_key": 1}, 20)
def test_background_update_default_sleep_behavior(self): def test_background_update_default_sleep_behavior(self) -> None:
""" """
Test default background update behavior, which is to sleep Test default background update behavior, which is to sleep
""" """
@ -147,7 +151,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = self.update self.update_handler.side_effect = self.update
self.update_handler.reset_mock() self.update_handler.reset_mock()
self.updates.start_doing_background_updates(), self.updates.start_doing_background_updates()
# 2: advance the reactor less than the default sleep duration (1000ms) # 2: advance the reactor less than the default sleep duration (1000ms)
self.reactor.pump([0.5]) self.reactor.pump([0.5])
@ -167,7 +171,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
""" """
) )
) )
def test_background_update_sleep_set_in_config(self): def test_background_update_sleep_set_in_config(self) -> None:
""" """
Test that changing the sleep time in the config changes how long it sleeps Test that changing the sleep time in the config changes how long it sleeps
""" """
@ -181,7 +185,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = self.update self.update_handler.side_effect = self.update
self.update_handler.reset_mock() self.update_handler.reset_mock()
self.updates.start_doing_background_updates(), self.updates.start_doing_background_updates()
# 2: advance the reactor less than the configured sleep duration (500ms) # 2: advance the reactor less than the configured sleep duration (500ms)
self.reactor.pump([0.45]) self.reactor.pump([0.45])
@ -201,7 +205,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
""" """
) )
) )
def test_disabling_background_update_sleep(self): def test_disabling_background_update_sleep(self) -> None:
""" """
Test that disabling sleep in the config results in bg update not sleeping Test that disabling sleep in the config results in bg update not sleeping
""" """
@ -215,7 +219,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
self.update_handler.side_effect = self.update self.update_handler.side_effect = self.update
self.update_handler.reset_mock() self.update_handler.reset_mock()
self.updates.start_doing_background_updates(), self.updates.start_doing_background_updates()
# 2: advance the reactor very little # 2: advance the reactor very little
self.reactor.pump([0.025]) self.reactor.pump([0.025])
@ -230,7 +234,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
""" """
) )
) )
def test_background_update_duration_set_in_config(self): def test_background_update_duration_set_in_config(self) -> None:
""" """
Test that the desired duration set in the config is used in determining batch size Test that the desired duration set in the config is used in determining batch size
""" """
@ -254,7 +258,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# the first update was run with the default batch size, this should be run with 500ms as the # the first update was run with the default batch size, this should be run with 500ms as the
# desired duration # desired duration
async def update(progress, count): async def update(progress: JsonDict, count: int) -> int:
self.assertEqual(progress, {"my_key": 2}) self.assertEqual(progress, {"my_key": 2})
self.assertAlmostEqual( self.assertAlmostEqual(
count, count,
@ -275,7 +279,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
""" """
) )
) )
def test_background_update_min_batch_set_in_config(self): def test_background_update_min_batch_set_in_config(self) -> None:
""" """
Test that the minimum batch size set in the config is used Test that the minimum batch size set in the config is used
""" """
@ -290,7 +294,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
) )
# Run the update with the long-running update item # Run the update with the long-running update item
async def update(progress, count): async def update_long(progress: JsonDict, count: int) -> int:
await self.clock.sleep((count * duration_ms) / 1000) await self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1} progress = {"my_key": progress["my_key"] + 1}
await self.store.db_pool.runInteraction( await self.store.db_pool.runInteraction(
@ -301,7 +305,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
) )
return count return count
self.update_handler.side_effect = update self.update_handler.side_effect = update_long
self.update_handler.reset_mock() self.update_handler.reset_mock()
res = self.get_success( res = self.get_success(
self.updates.do_next_background_update(False), self.updates.do_next_background_update(False),
@ -311,25 +315,25 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# the first update was run with the default batch size, this should be run with minimum batch size # the first update was run with the default batch size, this should be run with minimum batch size
# as the first items took a very long time # as the first items took a very long time
async def update(progress, count): async def update_short(progress: JsonDict, count: int) -> int:
self.assertEqual(progress, {"my_key": 2}) self.assertEqual(progress, {"my_key": 2})
self.assertEqual(count, 5) self.assertEqual(count, 5)
await self.updates._end_background_update("test_update") await self.updates._end_background_update("test_update")
return count return count
self.update_handler.side_effect = update self.update_handler.side_effect = update_short
self.get_success(self.updates.do_next_background_update(False)) self.get_success(self.updates.do_next_background_update(False))
class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
# the base test class should have run the real bg updates for us # the base test class should have run the real bg updates for us
self.assertTrue( self.assertTrue(
self.get_success(self.updates.has_completed_background_updates()) self.get_success(self.updates.has_completed_background_updates())
) )
self.update_deferred = Deferred() self.update_deferred: Deferred[int] = Deferred()
self.update_handler = Mock(return_value=self.update_deferred) self.update_handler = Mock(return_value=self.update_deferred)
self.updates.register_background_update_handler( self.updates.register_background_update_handler(
"test_update", self.update_handler "test_update", self.update_handler
@ -358,7 +362,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
), ),
) )
def test_controller(self): def test_controller(self) -> None:
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
self.get_success( self.get_success(
store.db_pool.simple_insert( store.db_pool.simple_insert(
@ -368,7 +372,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
) )
# Set the return value for the context manager. # Set the return value for the context manager.
enter_defer = Deferred() enter_defer: Deferred[int] = Deferred()
self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer) self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer)
# Start the background update. # Start the background update.

View file

@ -13,9 +13,13 @@
# limitations under the License. # limitations under the License.
from typing import List, Optional from typing import List, Optional
from synapse.storage.database import DatabasePool from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
from tests.utils import USE_POSTGRES_FOR_TESTS from tests.utils import USE_POSTGRES_FOR_TESTS
@ -25,13 +29,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS: if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres" skip = "Requires Postgres"
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn): def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq") txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute( txn.execute(
""" """
@ -59,12 +63,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
def _insert_rows(self, instance_name: str, number: int): def _insert_rows(self, instance_name: str, number: int) -> None:
"""Insert N rows as the given instance, inserting with stream IDs pulled """Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence. from the postgres sequence.
""" """
def _insert(txn): def _insert(txn: LoggingTransaction) -> None:
for _ in range(number): for _ in range(number):
txn.execute( txn.execute(
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
@ -80,12 +84,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
def _insert_row_with_id(self, instance_name: str, stream_id: int): def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
"""Insert one row as the given instance with given stream_id, updating """Insert one row as the given instance with given stream_id, updating
the postgres sequence position to match. the postgres sequence position to match.
""" """
def _insert(txn): def _insert(txn: LoggingTransaction) -> None:
txn.execute( txn.execute(
"INSERT INTO foobar VALUES (?, ?)", "INSERT INTO foobar VALUES (?, ?)",
( (
@ -104,7 +108,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
def test_empty(self): def test_empty(self) -> None:
"""Test an ID generator against an empty database gives sensible """Test an ID generator against an empty database gives sensible
current positions. current positions.
""" """
@ -114,7 +118,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# The table is empty so we expect an empty map for positions # The table is empty so we expect an empty map for positions
self.assertEqual(id_gen.get_positions(), {}) self.assertEqual(id_gen.get_positions(), {})
def test_single_instance(self): def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled """Test that reads and writes from a single process are handled
correctly. correctly.
""" """
@ -130,7 +134,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Try allocating a new ID gen and check that we only see position # Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager. # advanced after we leave the context manager.
async def _get_next_async(): async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id: async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8) self.assertEqual(stream_id, 8)
@ -142,7 +146,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_out_of_order_finish(self): def test_out_of_order_finish(self) -> None:
"""Test that IDs persisted out of order are correctly handled""" """Test that IDs persisted out of order are correctly handled"""
# Prefill table with 7 rows written by 'master' # Prefill table with 7 rows written by 'master'
@ -191,7 +195,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 11}) self.assertEqual(id_gen.get_positions(), {"master": 11})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11) self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
def test_multi_instance(self): def test_multi_instance(self) -> None:
"""Test that reads and writes from multiple processes are handled """Test that reads and writes from multiple processes are handled
correctly. correctly.
""" """
@ -215,7 +219,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Try allocating a new ID gen and check that we only see position # Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager. # advanced after we leave the context manager.
async def _get_next_async(): async def _get_next_async() -> None:
async with first_id_gen.get_next() as stream_id: async with first_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8) self.assertEqual(stream_id, 8)
@ -233,7 +237,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# ... but calling `get_next` on the second instance should give a unique # ... but calling `get_next` on the second instance should give a unique
# stream ID # stream ID
async def _get_next_async(): async def _get_next_async2() -> None:
async with second_id_gen.get_next() as stream_id: async with second_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 9) self.assertEqual(stream_id, 9)
@ -241,7 +245,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen.get_positions(), {"first": 3, "second": 7} second_id_gen.get_positions(), {"first": 3, "second": 7}
) )
self.get_success(_get_next_async()) self.get_success(_get_next_async2())
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9}) self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
@ -249,7 +253,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen.advance("first", 8) second_id_gen.advance("first", 8)
self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
def test_get_next_txn(self): def test_get_next_txn(self) -> None:
"""Test that the `get_next_txn` function works correctly.""" """Test that the `get_next_txn` function works correctly."""
# Prefill table with 7 rows written by 'master' # Prefill table with 7 rows written by 'master'
@ -263,7 +267,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Try allocating a new ID gen and check that we only see position # Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager. # advanced after we leave the context manager.
def _get_next_txn(txn): def _get_next_txn(txn: LoggingTransaction) -> None:
stream_id = id_gen.get_next_txn(txn) stream_id = id_gen.get_next_txn(txn)
self.assertEqual(stream_id, 8) self.assertEqual(stream_id, 8)
@ -275,7 +279,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_get_persisted_upto_position(self): def test_get_persisted_upto_position(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to """Test that `get_persisted_upto_position` correctly tracks updates to
positions. positions.
""" """
@ -317,7 +321,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen.advance("second", 15) id_gen.advance("second", 15)
self.assertEqual(id_gen.get_persisted_upto_position(), 11) self.assertEqual(id_gen.get_persisted_upto_position(), 11)
def test_get_persisted_upto_position_get_next(self): def test_get_persisted_upto_position_get_next(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to """Test that `get_persisted_upto_position` correctly tracks updates to
positions when `get_next` is called. positions when `get_next` is called.
""" """
@ -331,7 +335,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_persisted_upto_position(), 5) self.assertEqual(id_gen.get_persisted_upto_position(), 5)
async def _get_next_async(): async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id: async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 6) self.assertEqual(stream_id, 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 5) self.assertEqual(id_gen.get_persisted_upto_position(), 5)
@ -344,7 +348,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# `persisted_upto_position` in this case, then it will be correct in the # `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code). # other cases that are tested above (since they'll hit the same code).
def test_restart_during_out_of_order_persistence(self): def test_restart_during_out_of_order_persistence(self) -> None:
"""Test that restarting a process while another process is writing out """Test that restarting a process while another process is writing out
of order updates are handled correctly. of order updates are handled correctly.
""" """
@ -388,7 +392,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen_worker.advance("master", 9) id_gen_worker.advance("master", 9)
self.assertEqual(id_gen_worker.get_positions(), {"master": 9}) self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
def test_writer_config_change(self): def test_writer_config_change(self) -> None:
"""Test that changing the writer config correctly works.""" """Test that changing the writer config correctly works."""
self._insert_row_with_id("first", 3) self._insert_row_with_id("first", 3)
@ -421,7 +425,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Check that we get a sane next stream ID with this new config. # Check that we get a sane next stream ID with this new config.
async def _get_next_async(): async def _get_next_async() -> None:
async with id_gen_3.get_next() as stream_id: async with id_gen_3.get_next() as stream_id:
self.assertEqual(stream_id, 6) self.assertEqual(stream_id, 6)
@ -435,7 +439,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6) self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6)
self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6) self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
def test_sequence_consistency(self): def test_sequence_consistency(self) -> None:
"""Test that we error out if the table and sequence diverges.""" """Test that we error out if the table and sequence diverges."""
# Prefill with some rows # Prefill with some rows
@ -458,13 +462,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS: if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres" skip = "Requires Postgres"
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn): def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq") txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute( txn.execute(
""" """
@ -493,10 +497,10 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
return self.get_success(self.db_pool.runWithConnection(_create)) return self.get_success(self.db_pool.runWithConnection(_create))
def _insert_row(self, instance_name: str, stream_id: int): def _insert_row(self, instance_name: str, stream_id: int) -> None:
"""Insert one row as the given instance with given stream_id.""" """Insert one row as the given instance with given stream_id."""
def _insert(txn): def _insert(txn: LoggingTransaction) -> None:
txn.execute( txn.execute(
"INSERT INTO foobar VALUES (?, ?)", "INSERT INTO foobar VALUES (?, ?)",
( (
@ -514,13 +518,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_row", _insert)) self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
def test_single_instance(self): def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled """Test that reads and writes from a single process are handled
correctly. correctly.
""" """
id_gen = self._create_id_generator() id_gen = self._create_id_generator()
async def _get_next_async(): async def _get_next_async() -> None:
async with id_gen.get_next() as stream_id: async with id_gen.get_next() as stream_id:
self._insert_row("master", stream_id) self._insert_row("master", stream_id)
@ -530,7 +534,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1) self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
self.assertEqual(id_gen.get_persisted_upto_position(), -1) self.assertEqual(id_gen.get_persisted_upto_position(), -1)
async def _get_next_async2(): async def _get_next_async2() -> None:
async with id_gen.get_next_mult(3) as stream_ids: async with id_gen.get_next_mult(3) as stream_ids:
for stream_id in stream_ids: for stream_id in stream_ids:
self._insert_row("master", stream_id) self._insert_row("master", stream_id)
@ -548,14 +552,14 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4) self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
self.assertEqual(second_id_gen.get_persisted_upto_position(), -4) self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
def test_multiple_instance(self): def test_multiple_instance(self) -> None:
"""Tests that having multiple instances that get advanced over """Tests that having multiple instances that get advanced over
federation works corretly. federation works corretly.
""" """
id_gen_1 = self._create_id_generator("first", writers=["first", "second"]) id_gen_1 = self._create_id_generator("first", writers=["first", "second"])
id_gen_2 = self._create_id_generator("second", writers=["first", "second"]) id_gen_2 = self._create_id_generator("second", writers=["first", "second"])
async def _get_next_async(): async def _get_next_async() -> None:
async with id_gen_1.get_next() as stream_id: async with id_gen_1.get_next() as stream_id:
self._insert_row("first", stream_id) self._insert_row("first", stream_id)
id_gen_2.advance("first", stream_id) id_gen_2.advance("first", stream_id)
@ -567,7 +571,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1) self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1) self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
async def _get_next_async2(): async def _get_next_async2() -> None:
async with id_gen_2.get_next() as stream_id: async with id_gen_2.get_next() as stream_id:
self._insert_row("second", stream_id) self._insert_row("second", stream_id)
id_gen_1.advance("second", stream_id) id_gen_1.advance("second", stream_id)
@ -584,13 +588,13 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS: if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres" skip = "Requires Postgres"
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn): def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq") txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute( txn.execute(
""" """
@ -642,7 +646,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
from the postgres sequence. from the postgres sequence.
""" """
def _insert(txn): def _insert(txn: LoggingTransaction) -> None:
for _ in range(number): for _ in range(number):
txn.execute( txn.execute(
"INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,), "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
@ -659,7 +663,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
def test_load_existing_stream(self): def test_load_existing_stream(self) -> None:
"""Test creating ID gens with multiple tables that have rows from after """Test creating ID gens with multiple tables that have rows from after
the position in `stream_positions` table. the position in `stream_positions` table.
""" """