Add type hints for tests/unittest.py. (#12347)

In particular, add type hints for get_success and friends, which are then helpful in a bunch of places.
This commit is contained in:
Richard van der Hoff 2022-04-01 17:04:16 +01:00 committed by GitHub
parent 33ebee47e4
commit f0b03186d9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 97 additions and 48 deletions

1
changelog.d/12347.misc Normal file
View file

@ -0,0 +1 @@
Add type annotations for `tests/unittest.py`.

View file

@ -83,7 +83,6 @@ exclude = (?x)
|tests/test_server.py |tests/test_server.py
|tests/test_state.py |tests/test_state.py
|tests/test_terms_auth.py |tests/test_terms_auth.py
|tests/unittest.py
|tests/util/caches/test_cached_call.py |tests/util/caches/test_cached_call.py
|tests/util/caches/test_deferred_cache.py |tests/util/caches/test_deferred_cache.py
|tests/util/caches/test_descriptors.py |tests/util/caches/test_descriptors.py

View file

@ -463,8 +463,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code res = e.value.code
self.assertEqual(res, 400) self.assertEqual(res, 400)
res = self.get_success(self.handler.query_local_devices({local_user: None})) query_res = self.get_success(
self.assertDictEqual(res, {local_user: {}}) self.handler.query_local_devices({local_user: None})
)
self.assertDictEqual(query_res, {local_user: {}})
def test_upload_signatures(self) -> None: def test_upload_signatures(self) -> None:
"""should check signatures that are uploaded""" """should check signatures that are uploaded"""

View file

@ -375,7 +375,8 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
member_event.signatures = member_event_dict["signatures"] member_event.signatures = member_event_dict["signatures"]
# Add the new member_event to the StateMap # Add the new member_event to the StateMap
prev_state_map[ updated_state_map = dict(prev_state_map)
updated_state_map[
(member_event.type, member_event.state_key) (member_event.type, member_event.state_key)
] = member_event.event_id ] = member_event.event_id
auth_events.append(member_event) auth_events.append(member_event)
@ -399,7 +400,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
prev_event_ids=message_event_dict["prev_events"], prev_event_ids=message_event_dict["prev_events"],
auth_event_ids=self._event_auth_handler.compute_auth_events( auth_event_ids=self._event_auth_handler.compute_auth_events(
builder, builder,
prev_state_map, updated_state_map,
for_verification=False, for_verification=False,
), ),
depth=message_event_dict["depth"], depth=message_event_dict["depth"],

View file

@ -354,10 +354,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
req = Mock(spec=["cookies"]) req = Mock(spec=["cookies"])
req.cookies = [] req.cookies = []
url = self.get_success( url = urlparse(
self.provider.handle_redirect_request(req, b"http://client/redirect") self.get_success(
self.provider.handle_redirect_request(req, b"http://client/redirect")
)
) )
url = urlparse(url)
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
self.assertEqual(url.scheme, auth_endpoint.scheme) self.assertEqual(url.scheme, auth_endpoint.scheme)

View file

@ -351,6 +351,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.handler.handle_local_profile_change(regular_user_id, profile_info) self.handler.handle_local_profile_change(regular_user_id, profile_info)
) )
profile = self.get_success(self.store.get_user_in_directory(regular_user_id)) profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
assert profile is not None
self.assertTrue(profile["display_name"] == display_name) self.assertTrue(profile["display_name"] == display_name)
def test_handle_local_profile_change_with_deactivated_user(self) -> None: def test_handle_local_profile_change_with_deactivated_user(self) -> None:
@ -369,6 +370,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is in directory # profile is in directory
profile = self.get_success(self.store.get_user_in_directory(r_user_id)) profile = self.get_success(self.store.get_user_in_directory(r_user_id))
assert profile is not None
self.assertTrue(profile["display_name"] == display_name) self.assertTrue(profile["display_name"] == display_name)
# deactivate user # deactivate user

View file

@ -702,6 +702,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
""" """
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["quarantined_by"]) self.assertFalse(media_info["quarantined_by"])
# quarantining # quarantining
@ -715,6 +716,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(channel.json_body) self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertTrue(media_info["quarantined_by"]) self.assertTrue(media_info["quarantined_by"])
# remove from quarantine # remove from quarantine
@ -728,6 +730,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(channel.json_body) self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["quarantined_by"]) self.assertFalse(media_info["quarantined_by"])
def test_quarantine_protected_media(self) -> None: def test_quarantine_protected_media(self) -> None:
@ -740,6 +743,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
# verify protection # verify protection
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertTrue(media_info["safe_from_quarantine"]) self.assertTrue(media_info["safe_from_quarantine"])
# quarantining # quarantining
@ -754,6 +758,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
# verify that is not in quarantine # verify that is not in quarantine
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["quarantined_by"]) self.assertFalse(media_info["quarantined_by"])
@ -830,6 +835,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
""" """
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["safe_from_quarantine"]) self.assertFalse(media_info["safe_from_quarantine"])
# protect # protect
@ -843,6 +849,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(channel.json_body) self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertTrue(media_info["safe_from_quarantine"]) self.assertTrue(media_info["safe_from_quarantine"])
# unprotect # unprotect
@ -856,6 +863,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(channel.json_body) self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
self.assertFalse(media_info["safe_from_quarantine"]) self.assertFalse(media_info["safe_from_quarantine"])

View file

@ -1590,10 +1590,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
pushers = self.get_success( pushers = list(
self.store.get_pushers_by({"user_name": "@bob:test"}) self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
) )
pushers = list(pushers)
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
self.assertEqual("@bob:test", pushers[0].user_name) self.assertEqual("@bob:test", pushers[0].user_name)
@ -1632,10 +1631,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
pushers = self.get_success( pushers = list(
self.store.get_pushers_by({"user_name": "@bob:test"}) self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
) )
pushers = list(pushers)
self.assertEqual(len(pushers), 0) self.assertEqual(len(pushers), 0)
def test_set_password(self) -> None: def test_set_password(self) -> None:
@ -2144,6 +2142,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# is in user directory # is in user directory
profile = self.get_success(self.store.get_user_in_directory(self.other_user)) profile = self.get_success(self.store.get_user_in_directory(self.other_user))
assert profile is not None
self.assertTrue(profile["display_name"] == "User") self.assertTrue(profile["display_name"] == "User")
# Deactivate user # Deactivate user
@ -2711,6 +2710,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.store.get_user_by_access_token(other_user_token) self.store.get_user_by_access_token(other_user_token)
) )
assert user_tuple is not None
token_id = user_tuple.token_id token_id = user_tuple.token_id
self.get_success( self.get_success(
@ -3676,6 +3676,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
# The user starts off as not shadow-banned. # The user starts off as not shadow-banned.
other_user_token = self.login("user", "pass") other_user_token = self.login("user", "pass")
result = self.get_success(self.store.get_user_by_access_token(other_user_token)) result = self.get_success(self.store.get_user_by_access_token(other_user_token))
assert result is not None
self.assertFalse(result.shadow_banned) self.assertFalse(result.shadow_banned)
channel = self.make_request("POST", self.url, access_token=self.admin_user_tok) channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
@ -3684,6 +3685,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
# Ensure the user is shadow-banned (and the cache was cleared). # Ensure the user is shadow-banned (and the cache was cleared).
result = self.get_success(self.store.get_user_by_access_token(other_user_token)) result = self.get_success(self.store.get_user_by_access_token(other_user_token))
assert result is not None
self.assertTrue(result.shadow_banned) self.assertTrue(result.shadow_banned)
# Un-shadow-ban the user. # Un-shadow-ban the user.
@ -3695,6 +3697,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
# Ensure the user is no longer shadow-banned (and the cache was cleared). # Ensure the user is no longer shadow-banned (and the cache was cleared).
result = self.get_success(self.store.get_user_by_access_token(other_user_token)) result = self.get_success(self.store.get_user_by_access_token(other_user_token))
assert result is not None
self.assertFalse(result.shadow_banned) self.assertFalse(result.shadow_banned)

View file

@ -22,7 +22,6 @@ import warnings
from collections import deque from collections import deque
from io import SEEK_END, BytesIO from io import SEEK_END, BytesIO
from typing import ( from typing import (
AnyStr,
Callable, Callable,
Dict, Dict,
Iterable, Iterable,
@ -86,6 +85,9 @@ from tests.utils import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# the type of thing that can be passed into `make_request` in the headers list
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
class TimedOutException(Exception): class TimedOutException(Exception):
""" """
@ -260,7 +262,7 @@ def make_request(
federation_auth_origin: Optional[bytes] = None, federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False, content_is_form: bool = False,
await_result: bool = True, await_result: bool = True,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, custom_headers: Optional[Iterable[CustomHeaderType]] = None,
client_ip: str = "127.0.0.1", client_ip: str = "127.0.0.1",
) -> FakeChannel: ) -> FakeChannel:
""" """

View file

@ -28,7 +28,7 @@ class LockTestCase(unittest.HomeserverTestCase):
""" """
# First to acquire this lock, so it should complete # First to acquire this lock, so it should complete
lock = self.get_success(self.store.try_acquire_lock("name", "key")) lock = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock) assert lock is not None
# Enter the context manager # Enter the context manager
self.get_success(lock.__aenter__()) self.get_success(lock.__aenter__())
@ -45,7 +45,7 @@ class LockTestCase(unittest.HomeserverTestCase):
# We can now acquire the lock again. # We can now acquire the lock again.
lock3 = self.get_success(self.store.try_acquire_lock("name", "key")) lock3 = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock3) assert lock3 is not None
self.get_success(lock3.__aenter__()) self.get_success(lock3.__aenter__())
self.get_success(lock3.__aexit__(None, None, None)) self.get_success(lock3.__aexit__(None, None, None))
@ -53,7 +53,7 @@ class LockTestCase(unittest.HomeserverTestCase):
"""Test that we don't time out locks while they're still active""" """Test that we don't time out locks while they're still active"""
lock = self.get_success(self.store.try_acquire_lock("name", "key")) lock = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock) assert lock is not None
self.get_success(lock.__aenter__()) self.get_success(lock.__aenter__())
@ -69,7 +69,7 @@ class LockTestCase(unittest.HomeserverTestCase):
"""Test that we time out locks if they're not updated for ages""" """Test that we time out locks if they're not updated for ages"""
lock = self.get_success(self.store.try_acquire_lock("name", "key")) lock = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock) assert lock is not None
self.get_success(lock.__aenter__()) self.get_success(lock.__aenter__())

View file

@ -358,6 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(service.id, 12, other_events)) self.get_success(self._insert_txn(service.id, 12, other_events))
txn = self.get_success(self.store.get_oldest_unsent_txn(service)) txn = self.get_success(self.store.get_oldest_unsent_txn(service))
assert txn is not None
self.assertEqual(service, txn.service) self.assertEqual(service, txn.service)
self.assertEqual(10, txn.id) self.assertEqual(10, txn.id)
self.assertEqual(events, txn.events) self.assertEqual(events, txn.events)

View file

@ -22,10 +22,11 @@ import secrets
import time import time
from typing import ( from typing import (
Any, Any,
AnyStr, Awaitable,
Callable, Callable,
ClassVar, ClassVar,
Dict, Dict,
Generic,
Iterable, Iterable,
List, List,
Optional, Optional,
@ -39,6 +40,7 @@ from unittest.mock import Mock, patch
import canonicaljson import canonicaljson
import signedjson.key import signedjson.key
import unpaddedbase64 import unpaddedbase64
from typing_extensions import Protocol
from twisted.internet.defer import Deferred, ensureDeferred from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure from twisted.python.failure import Failure
@ -49,7 +51,7 @@ from twisted.web.resource import Resource
from twisted.web.server import Request from twisted.web.server import Request
from synapse import events from synapse import events
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.config.server import DEFAULT_ROOM_VERSION
@ -70,7 +72,13 @@ from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver from tests.server import (
CustomHeaderType,
FakeChannel,
get_clock,
make_request,
setup_test_homeserver,
)
from tests.test_utils import event_injection, setup_awaitable_errors from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb from tests.utils import default_config, setupdb
@ -78,6 +86,17 @@ from tests.utils import default_config, setupdb
setupdb() setupdb()
setup_logging() setup_logging()
TV = TypeVar("TV")
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
class _TypedFailure(Generic[_ExcType], Protocol):
"""Extension to twisted.Failure, where the 'value' has a certain type."""
@property
def value(self) -> _ExcType:
...
def around(target): def around(target):
"""A CLOS-style 'around' modifier, which wraps the original method of the """A CLOS-style 'around' modifier, which wraps the original method of the
@ -276,6 +295,7 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "user_id"): if hasattr(self, "user_id"):
if self.hijack_auth: if self.hijack_auth:
assert self.helper.auth_user_id is not None
# We need a valid token ID to satisfy foreign key constraints. # We need a valid token ID to satisfy foreign key constraints.
token_id = self.get_success( token_id = self.get_success(
@ -288,6 +308,7 @@ class HomeserverTestCase(TestCase):
) )
async def get_user_by_access_token(token=None, allow_guest=False): async def get_user_by_access_token(token=None, allow_guest=False):
assert self.helper.auth_user_id is not None
return { return {
"user": UserID.from_string(self.helper.auth_user_id), "user": UserID.from_string(self.helper.auth_user_id),
"token_id": token_id, "token_id": token_id,
@ -295,6 +316,7 @@ class HomeserverTestCase(TestCase):
} }
async def get_user_by_req(request, allow_guest=False, rights="access"): async def get_user_by_req(request, allow_guest=False, rights="access"):
assert self.helper.auth_user_id is not None
return create_requester( return create_requester(
UserID.from_string(self.helper.auth_user_id), UserID.from_string(self.helper.auth_user_id),
token_id, token_id,
@ -311,7 +333,7 @@ class HomeserverTestCase(TestCase):
) )
if self.needs_threadpool: if self.needs_threadpool:
self.reactor.threadpool = ThreadPool() self.reactor.threadpool = ThreadPool() # type: ignore[assignment]
self.addCleanup(self.reactor.threadpool.stop) self.addCleanup(self.reactor.threadpool.stop)
self.reactor.threadpool.start() self.reactor.threadpool.start()
@ -426,7 +448,7 @@ class HomeserverTestCase(TestCase):
federation_auth_origin: Optional[bytes] = None, federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False, content_is_form: bool = False,
await_result: bool = True, await_result: bool = True,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, custom_headers: Optional[Iterable[CustomHeaderType]] = None,
client_ip: str = "127.0.0.1", client_ip: str = "127.0.0.1",
) -> FakeChannel: ) -> FakeChannel:
""" """
@ -511,30 +533,36 @@ class HomeserverTestCase(TestCase):
return hs return hs
def pump(self, by=0.0): def pump(self, by: float = 0.0) -> None:
""" """
Pump the reactor enough that Deferreds will fire. Pump the reactor enough that Deferreds will fire.
""" """
self.reactor.pump([by] * 100) self.reactor.pump([by] * 100)
def get_success(self, d, by=0.0): def get_success(
deferred: Deferred[TV] = ensureDeferred(d) self,
d: Awaitable[TV],
by: float = 0.0,
) -> TV:
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
self.pump(by=by) self.pump(by=by)
return self.successResultOf(deferred) return self.successResultOf(deferred)
def get_failure(self, d, exc): def get_failure(
self, d: Awaitable[Any], exc: Type[_ExcType]
) -> _TypedFailure[_ExcType]:
""" """
Run a Deferred and get a Failure from it. The failure must be of the type `exc`. Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
""" """
deferred: Deferred[Any] = ensureDeferred(d) deferred: Deferred[Any] = ensureDeferred(d) # type: ignore[arg-type]
self.pump() self.pump()
return self.failureResultOf(deferred, exc) return self.failureResultOf(deferred, exc)
def get_success_or_raise(self, d, by=0.0): def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
"""Drive deferred to completion and return result or raise exception """Drive deferred to completion and return result or raise exception
on failure. on failure.
""" """
deferred: Deferred[TV] = ensureDeferred(d) deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
results: list = [] results: list = []
deferred.addBoth(results.append) deferred.addBoth(results.append)
@ -642,11 +670,11 @@ class HomeserverTestCase(TestCase):
def login( def login(
self, self,
username, username: str,
password, password: str,
device_id=None, device_id: Optional[str] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, custom_headers: Optional[Iterable[CustomHeaderType]] = None,
): ) -> str:
""" """
Log in a user, and get an access token. Requires the Login API be Log in a user, and get an access token. Requires the Login API be
registered. registered.
@ -668,18 +696,22 @@ class HomeserverTestCase(TestCase):
return access_token return access_token
def create_and_send_event( def create_and_send_event(
self, room_id, user, soft_failed=False, prev_event_ids=None self,
): room_id: str,
user: UserID,
soft_failed: bool = False,
prev_event_ids: Optional[List[str]] = None,
) -> str:
""" """
Create and send an event. Create and send an event.
Args: Args:
soft_failed (bool): Whether to create a soft failed event or not soft_failed: Whether to create a soft failed event or not
prev_event_ids (list[str]|None): Explicitly set the prev events, prev_event_ids: Explicitly set the prev events,
or if None just use the default or if None just use the default
Returns: Returns:
str: The new event's ID. The new event's ID.
""" """
event_creator = self.hs.get_event_creation_handler() event_creator = self.hs.get_event_creation_handler()
requester = create_requester(user) requester = create_requester(user)
@ -706,7 +738,7 @@ class HomeserverTestCase(TestCase):
return event.event_id return event.event_id
def inject_room_member(self, room: str, user: str, membership: Membership) -> None: def inject_room_member(self, room: str, user: str, membership: str) -> None:
""" """
Inject a membership event into a room. Inject a membership event into a room.
@ -766,7 +798,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
path: str, path: str,
content: Optional[JsonDict] = None, content: Optional[JsonDict] = None,
await_result: bool = True, await_result: bool = True,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, custom_headers: Optional[Iterable[CustomHeaderType]] = None,
client_ip: str = "127.0.0.1", client_ip: str = "127.0.0.1",
) -> FakeChannel: ) -> FakeChannel:
"""Make an inbound signed federation request to this server """Make an inbound signed federation request to this server
@ -799,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
self.site, self.site,
method=method, method=method,
path=path, path=path,
content=content, content=content or "",
shorthand=False, shorthand=False,
await_result=await_result, await_result=await_result,
custom_headers=custom_headers, custom_headers=custom_headers,
@ -878,9 +910,6 @@ def override_config(extra_config):
return decorator return decorator
TV = TypeVar("TV")
def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]: def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
"""A test decorator which will skip the decorated test unless a condition is set """A test decorator which will skip the decorated test unless a condition is set