mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-25 11:05:49 +03:00
Add missing type hints for tests.unittest. (#13397)
This commit is contained in:
parent
502f075e96
commit
922b771337
6 changed files with 66 additions and 52 deletions
1
changelog.d/13397.misc
Normal file
1
changelog.d/13397.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Adding missing type hints to tests.
|
|
@ -481,17 +481,13 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def prepare(
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
|
|
||||||
) -> HomeServer:
|
|
||||||
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
|
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
|
||||||
self.allowed_access_token = self.login(self.allowed_localpart, "pass")
|
self.allowed_access_token = self.login(self.allowed_localpart, "pass")
|
||||||
|
|
||||||
self.denied_user_id = self.register_user("denied", "pass")
|
self.denied_user_id = self.register_user("denied", "pass")
|
||||||
self.denied_access_token = self.login("denied", "pass")
|
self.denied_access_token = self.login("denied", "pass")
|
||||||
|
|
||||||
return hs
|
|
||||||
|
|
||||||
def test_denied_without_publication_permission(self) -> None:
|
def test_denied_without_publication_permission(self) -> None:
|
||||||
"""
|
"""
|
||||||
Try to create a room, register an alias for it, and publish it,
|
Try to create a room, register an alias for it, and publish it,
|
||||||
|
@ -575,9 +571,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
servlets = [directory.register_servlets, room.register_servlets]
|
servlets = [directory.register_servlets, room.register_servlets]
|
||||||
|
|
||||||
def prepare(
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
|
|
||||||
) -> HomeServer:
|
|
||||||
room_id = self.helper.create_room_as(self.user_id)
|
room_id = self.helper.create_room_as(self.user_id)
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
|
@ -588,8 +582,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
|
||||||
self.room_list_handler = hs.get_room_list_handler()
|
self.room_list_handler = hs.get_room_list_handler()
|
||||||
self.directory_handler = hs.get_directory_handler()
|
self.directory_handler = hs.get_directory_handler()
|
||||||
|
|
||||||
return hs
|
|
||||||
|
|
||||||
def test_disabling_room_list(self) -> None:
|
def test_disabling_room_list(self) -> None:
|
||||||
self.room_list_handler.enable_room_list_search = True
|
self.room_list_handler.enable_room_list_search = True
|
||||||
self.directory_handler.enable_room_list_search = True
|
self.directory_handler.enable_room_list_search = True
|
||||||
|
|
|
@ -1060,6 +1060,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||||
participated, bundled_aggregations.get("current_user_participated")
|
participated, bundled_aggregations.get("current_user_participated")
|
||||||
)
|
)
|
||||||
# The latest thread event has some fields that don't matter.
|
# The latest thread event has some fields that don't matter.
|
||||||
|
self.assertIn("latest_event", bundled_aggregations)
|
||||||
self.assert_dict(
|
self.assert_dict(
|
||||||
{
|
{
|
||||||
"content": {
|
"content": {
|
||||||
|
@ -1072,7 +1073,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||||
"sender": self.user2_id,
|
"sender": self.user2_id,
|
||||||
"type": "m.room.test",
|
"type": "m.room.test",
|
||||||
},
|
},
|
||||||
bundled_aggregations.get("latest_event"),
|
bundled_aggregations["latest_event"],
|
||||||
)
|
)
|
||||||
|
|
||||||
return assert_thread
|
return assert_thread
|
||||||
|
@ -1112,6 +1113,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||||
self.assertEqual(2, bundled_aggregations.get("count"))
|
self.assertEqual(2, bundled_aggregations.get("count"))
|
||||||
self.assertTrue(bundled_aggregations.get("current_user_participated"))
|
self.assertTrue(bundled_aggregations.get("current_user_participated"))
|
||||||
# The latest thread event has some fields that don't matter.
|
# The latest thread event has some fields that don't matter.
|
||||||
|
self.assertIn("latest_event", bundled_aggregations)
|
||||||
self.assert_dict(
|
self.assert_dict(
|
||||||
{
|
{
|
||||||
"content": {
|
"content": {
|
||||||
|
@ -1124,7 +1126,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||||
"sender": self.user_id,
|
"sender": self.user_id,
|
||||||
"type": "m.room.test",
|
"type": "m.room.test",
|
||||||
},
|
},
|
||||||
bundled_aggregations.get("latest_event"),
|
bundled_aggregations["latest_event"],
|
||||||
)
|
)
|
||||||
# Check the unsigned field on the latest event.
|
# Check the unsigned field on the latest event.
|
||||||
self.assert_dict(
|
self.assert_dict(
|
||||||
|
|
|
@ -496,7 +496,7 @@ class RoomStateTestCase(RoomBase):
|
||||||
|
|
||||||
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||||
self.assertCountEqual(
|
self.assertCountEqual(
|
||||||
[state_event["type"] for state_event in channel.json_body],
|
[state_event["type"] for state_event in channel.json_list],
|
||||||
{
|
{
|
||||||
"m.room.create",
|
"m.room.create",
|
||||||
"m.room.power_levels",
|
"m.room.power_levels",
|
||||||
|
|
|
@ -25,6 +25,7 @@ from typing import (
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
|
List,
|
||||||
MutableMapping,
|
MutableMapping,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
@ -121,7 +122,15 @@ class FakeChannel:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def json_body(self) -> JsonDict:
|
def json_body(self) -> JsonDict:
|
||||||
return json.loads(self.text_body)
|
body = json.loads(self.text_body)
|
||||||
|
assert isinstance(body, dict)
|
||||||
|
return body
|
||||||
|
|
||||||
|
@property
|
||||||
|
def json_list(self) -> List[JsonDict]:
|
||||||
|
body = json.loads(self.text_body)
|
||||||
|
assert isinstance(body, list)
|
||||||
|
return body
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text_body(self) -> str:
|
def text_body(self) -> str:
|
||||||
|
|
|
@ -28,6 +28,7 @@ from typing import (
|
||||||
Generic,
|
Generic,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
|
NoReturn,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
|
@ -39,7 +40,7 @@ from unittest.mock import Mock, patch
|
||||||
import canonicaljson
|
import canonicaljson
|
||||||
import signedjson.key
|
import signedjson.key
|
||||||
import unpaddedbase64
|
import unpaddedbase64
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Concatenate, ParamSpec, Protocol
|
||||||
|
|
||||||
from twisted.internet.defer import Deferred, ensureDeferred
|
from twisted.internet.defer import Deferred, ensureDeferred
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
|
@ -67,7 +68,7 @@ from synapse.logging.context import (
|
||||||
from synapse.rest import RegisterServletsFunc
|
from synapse.rest import RegisterServletsFunc
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.keys import FetchKeyResult
|
from synapse.storage.keys import FetchKeyResult
|
||||||
from synapse.types import JsonDict, UserID, create_requester
|
from synapse.types import JsonDict, Requester, UserID, create_requester
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
from synapse.util.httpresourcetree import create_resource_tree
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
|
|
||||||
|
@ -88,6 +89,10 @@ setup_logging()
|
||||||
TV = TypeVar("TV")
|
TV = TypeVar("TV")
|
||||||
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
|
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
S = TypeVar("S")
|
||||||
|
|
||||||
|
|
||||||
class _TypedFailure(Generic[_ExcType], Protocol):
|
class _TypedFailure(Generic[_ExcType], Protocol):
|
||||||
"""Extension to twisted.Failure, where the 'value' has a certain type."""
|
"""Extension to twisted.Failure, where the 'value' has a certain type."""
|
||||||
|
@ -97,7 +102,7 @@ class _TypedFailure(Generic[_ExcType], Protocol):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
def around(target):
|
def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:
|
||||||
"""A CLOS-style 'around' modifier, which wraps the original method of the
|
"""A CLOS-style 'around' modifier, which wraps the original method of the
|
||||||
given instance with another piece of code.
|
given instance with another piece of code.
|
||||||
|
|
||||||
|
@ -106,11 +111,11 @@ def around(target):
|
||||||
return orig(*args, **kwargs)
|
return orig(*args, **kwargs)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _around(code):
|
def _around(code: Callable[Concatenate[S, P], R]) -> None:
|
||||||
name = code.__name__
|
name = code.__name__
|
||||||
orig = getattr(target, name)
|
orig = getattr(target, name)
|
||||||
|
|
||||||
def new(*args, **kwargs):
|
def new(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
return code(orig, *args, **kwargs)
|
return code(orig, *args, **kwargs)
|
||||||
|
|
||||||
setattr(target, name, new)
|
setattr(target, name, new)
|
||||||
|
@ -131,7 +136,7 @@ class TestCase(unittest.TestCase):
|
||||||
level = getattr(method, "loglevel", getattr(self, "loglevel", None))
|
level = getattr(method, "loglevel", getattr(self, "loglevel", None))
|
||||||
|
|
||||||
@around(self)
|
@around(self)
|
||||||
def setUp(orig):
|
def setUp(orig: Callable[[], R]) -> R:
|
||||||
# if we're not starting in the sentinel logcontext, then to be honest
|
# if we're not starting in the sentinel logcontext, then to be honest
|
||||||
# all future bets are off.
|
# all future bets are off.
|
||||||
if current_context():
|
if current_context():
|
||||||
|
@ -144,7 +149,7 @@ class TestCase(unittest.TestCase):
|
||||||
if level is not None and old_level != level:
|
if level is not None and old_level != level:
|
||||||
|
|
||||||
@around(self)
|
@around(self)
|
||||||
def tearDown(orig):
|
def tearDown(orig: Callable[[], R]) -> R:
|
||||||
ret = orig()
|
ret = orig()
|
||||||
logging.getLogger().setLevel(old_level)
|
logging.getLogger().setLevel(old_level)
|
||||||
return ret
|
return ret
|
||||||
|
@ -158,7 +163,7 @@ class TestCase(unittest.TestCase):
|
||||||
return orig()
|
return orig()
|
||||||
|
|
||||||
@around(self)
|
@around(self)
|
||||||
def tearDown(orig):
|
def tearDown(orig: Callable[[], R]) -> R:
|
||||||
ret = orig()
|
ret = orig()
|
||||||
# force a GC to workaround problems with deferreds leaking logcontexts when
|
# force a GC to workaround problems with deferreds leaking logcontexts when
|
||||||
# they are GCed (see the logcontext docs)
|
# they are GCed (see the logcontext docs)
|
||||||
|
@ -167,7 +172,7 @@ class TestCase(unittest.TestCase):
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def assertObjectHasAttributes(self, attrs, obj):
|
def assertObjectHasAttributes(self, attrs: Dict[str, object], obj: object) -> None:
|
||||||
"""Asserts that the given object has each of the attributes given, and
|
"""Asserts that the given object has each of the attributes given, and
|
||||||
that the value of each matches according to assertEqual."""
|
that the value of each matches according to assertEqual."""
|
||||||
for key in attrs.keys():
|
for key in attrs.keys():
|
||||||
|
@ -178,12 +183,12 @@ class TestCase(unittest.TestCase):
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
raise (type(e))(f"Assert error for '.{key}':") from e
|
raise (type(e))(f"Assert error for '.{key}':") from e
|
||||||
|
|
||||||
def assert_dict(self, required, actual):
|
def assert_dict(self, required: dict, actual: dict) -> None:
|
||||||
"""Does a partial assert of a dict.
|
"""Does a partial assert of a dict.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
required (dict): The keys and value which MUST be in 'actual'.
|
required: The keys and value which MUST be in 'actual'.
|
||||||
actual (dict): The test result. Extra keys will not be checked.
|
actual: The test result. Extra keys will not be checked.
|
||||||
"""
|
"""
|
||||||
for key in required:
|
for key in required:
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -191,31 +196,31 @@ class TestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def DEBUG(target):
|
def DEBUG(target: TV) -> TV:
|
||||||
"""A decorator to set the .loglevel attribute to logging.DEBUG.
|
"""A decorator to set the .loglevel attribute to logging.DEBUG.
|
||||||
Can apply to either a TestCase or an individual test method."""
|
Can apply to either a TestCase or an individual test method."""
|
||||||
target.loglevel = logging.DEBUG
|
target.loglevel = logging.DEBUG # type: ignore[attr-defined]
|
||||||
return target
|
return target
|
||||||
|
|
||||||
|
|
||||||
def INFO(target):
|
def INFO(target: TV) -> TV:
|
||||||
"""A decorator to set the .loglevel attribute to logging.INFO.
|
"""A decorator to set the .loglevel attribute to logging.INFO.
|
||||||
Can apply to either a TestCase or an individual test method."""
|
Can apply to either a TestCase or an individual test method."""
|
||||||
target.loglevel = logging.INFO
|
target.loglevel = logging.INFO # type: ignore[attr-defined]
|
||||||
return target
|
return target
|
||||||
|
|
||||||
|
|
||||||
def logcontext_clean(target):
|
def logcontext_clean(target: TV) -> TV:
|
||||||
"""A decorator which marks the TestCase or method as 'logcontext_clean'
|
"""A decorator which marks the TestCase or method as 'logcontext_clean'
|
||||||
|
|
||||||
... ie, any logcontext errors should cause a test failure
|
... ie, any logcontext errors should cause a test failure
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def logcontext_error(msg):
|
def logcontext_error(msg: str) -> NoReturn:
|
||||||
raise AssertionError("logcontext error: %s" % (msg))
|
raise AssertionError("logcontext error: %s" % (msg))
|
||||||
|
|
||||||
patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
|
patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
|
||||||
return patcher(target)
|
return patcher(target) # type: ignore[call-overload]
|
||||||
|
|
||||||
|
|
||||||
class HomeserverTestCase(TestCase):
|
class HomeserverTestCase(TestCase):
|
||||||
|
@ -255,7 +260,7 @@ class HomeserverTestCase(TestCase):
|
||||||
method = getattr(self, methodName)
|
method = getattr(self, methodName)
|
||||||
self._extra_config = getattr(method, "_extra_config", None)
|
self._extra_config = getattr(method, "_extra_config", None)
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
"""
|
"""
|
||||||
Set up the TestCase by calling the homeserver constructor, optionally
|
Set up the TestCase by calling the homeserver constructor, optionally
|
||||||
hijacking the authentication system to return a fixed user, and then
|
hijacking the authentication system to return a fixed user, and then
|
||||||
|
@ -306,7 +311,9 @@ class HomeserverTestCase(TestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_user_by_access_token(token=None, allow_guest=False):
|
async def get_user_by_access_token(
|
||||||
|
token: Optional[str] = None, allow_guest: bool = False
|
||||||
|
) -> JsonDict:
|
||||||
assert self.helper.auth_user_id is not None
|
assert self.helper.auth_user_id is not None
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.helper.auth_user_id),
|
"user": UserID.from_string(self.helper.auth_user_id),
|
||||||
|
@ -314,7 +321,11 @@ class HomeserverTestCase(TestCase):
|
||||||
"is_guest": False,
|
"is_guest": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_user_by_req(request, allow_guest=False):
|
async def get_user_by_req(
|
||||||
|
request: SynapseRequest,
|
||||||
|
allow_guest: bool = False,
|
||||||
|
allow_expired: bool = False,
|
||||||
|
) -> Requester:
|
||||||
assert self.helper.auth_user_id is not None
|
assert self.helper.auth_user_id is not None
|
||||||
return create_requester(
|
return create_requester(
|
||||||
UserID.from_string(self.helper.auth_user_id),
|
UserID.from_string(self.helper.auth_user_id),
|
||||||
|
@ -339,11 +350,11 @@ class HomeserverTestCase(TestCase):
|
||||||
if hasattr(self, "prepare"):
|
if hasattr(self, "prepare"):
|
||||||
self.prepare(self.reactor, self.clock, self.hs)
|
self.prepare(self.reactor, self.clock, self.hs)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self) -> None:
|
||||||
# Reset to not use frozen dicts.
|
# Reset to not use frozen dicts.
|
||||||
events.USE_FROZEN_DICTS = False
|
events.USE_FROZEN_DICTS = False
|
||||||
|
|
||||||
def wait_on_thread(self, deferred, timeout=10):
|
def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None:
|
||||||
"""
|
"""
|
||||||
Wait until a Deferred is done, where it's waiting on a real thread.
|
Wait until a Deferred is done, where it's waiting on a real thread.
|
||||||
"""
|
"""
|
||||||
|
@ -374,7 +385,7 @@ class HomeserverTestCase(TestCase):
|
||||||
clock (synapse.util.Clock): The Clock, associated with the reactor.
|
clock (synapse.util.Clock): The Clock, associated with the reactor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A homeserver (synapse.server.HomeServer) suitable for testing.
|
A homeserver suitable for testing.
|
||||||
|
|
||||||
Function to be overridden in subclasses.
|
Function to be overridden in subclasses.
|
||||||
"""
|
"""
|
||||||
|
@ -408,7 +419,7 @@ class HomeserverTestCase(TestCase):
|
||||||
"/_synapse/admin": servlet_resource,
|
"/_synapse/admin": servlet_resource,
|
||||||
}
|
}
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self) -> JsonDict:
|
||||||
"""
|
"""
|
||||||
Get a default HomeServer config dict.
|
Get a default HomeServer config dict.
|
||||||
"""
|
"""
|
||||||
|
@ -421,7 +432,9 @@ class HomeserverTestCase(TestCase):
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
|
def prepare(
|
||||||
|
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Prepare for the test. This involves things like mocking out parts of
|
Prepare for the test. This involves things like mocking out parts of
|
||||||
the homeserver, or building test data common across the whole test
|
the homeserver, or building test data common across the whole test
|
||||||
|
@ -519,7 +532,7 @@ class HomeserverTestCase(TestCase):
|
||||||
config_obj.parse_config_dict(config, "", "")
|
config_obj.parse_config_dict(config, "", "")
|
||||||
kwargs["config"] = config_obj
|
kwargs["config"] = config_obj
|
||||||
|
|
||||||
async def run_bg_updates():
|
async def run_bg_updates() -> None:
|
||||||
with LoggingContext("run_bg_updates"):
|
with LoggingContext("run_bg_updates"):
|
||||||
self.get_success(stor.db_pool.updates.run_background_updates(False))
|
self.get_success(stor.db_pool.updates.run_background_updates(False))
|
||||||
|
|
||||||
|
@ -538,11 +551,7 @@ class HomeserverTestCase(TestCase):
|
||||||
"""
|
"""
|
||||||
self.reactor.pump([by] * 100)
|
self.reactor.pump([by] * 100)
|
||||||
|
|
||||||
def get_success(
|
def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
|
||||||
self,
|
|
||||||
d: Awaitable[TV],
|
|
||||||
by: float = 0.0,
|
|
||||||
) -> TV:
|
|
||||||
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
|
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
|
||||||
self.pump(by=by)
|
self.pump(by=by)
|
||||||
return self.successResultOf(deferred)
|
return self.successResultOf(deferred)
|
||||||
|
@ -755,7 +764,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
|
||||||
OTHER_SERVER_NAME = "other.example.com"
|
OTHER_SERVER_NAME = "other.example.com"
|
||||||
OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")
|
OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
super().prepare(reactor, clock, hs)
|
super().prepare(reactor, clock, hs)
|
||||||
|
|
||||||
# poke the other server's signing key into the key store, so that we don't
|
# poke the other server's signing key into the key store, so that we don't
|
||||||
|
@ -879,7 +888,7 @@ def _auth_header_for_request(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def override_config(extra_config):
|
def override_config(extra_config: JsonDict) -> Callable[[TV], TV]:
|
||||||
"""A decorator which can be applied to test functions to give additional HS config
|
"""A decorator which can be applied to test functions to give additional HS config
|
||||||
|
|
||||||
For use
|
For use
|
||||||
|
@ -892,12 +901,13 @@ def override_config(extra_config):
|
||||||
...
|
...
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
extra_config(dict): Additional config settings to be merged into the default
|
extra_config: Additional config settings to be merged into the default
|
||||||
config dict before instantiating the test homeserver.
|
config dict before instantiating the test homeserver.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func):
|
def decorator(func: TV) -> TV:
|
||||||
func._extra_config = extra_config
|
# This attribute is being defined.
|
||||||
|
func._extra_config = extra_config # type: ignore[attr-defined]
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
Loading…
Reference in a new issue