Add type hints to some tests/handlers files. (#12224)

This commit is contained in:
Dirk Klimpel 2022-03-15 14:16:37 +01:00 committed by GitHub
parent 2fcf4b3f6c
commit 5dd949bee6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 156 additions and 131 deletions

1
changelog.d/12224.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to tests files.

View file

@ -67,13 +67,8 @@ exclude = (?x)
|tests/federation/transport/test_knocking.py |tests/federation/transport/test_knocking.py
|tests/federation/transport/test_server.py |tests/federation/transport/test_server.py
|tests/handlers/test_cas.py |tests/handlers/test_cas.py
|tests/handlers/test_directory.py
|tests/handlers/test_e2e_keys.py
|tests/handlers/test_federation.py |tests/handlers/test_federation.py
|tests/handlers/test_oidc.py
|tests/handlers/test_presence.py |tests/handlers/test_presence.py
|tests/handlers/test_profile.py
|tests/handlers/test_saml.py
|tests/handlers/test_typing.py |tests/handlers/test_typing.py
|tests/http/federation/test_matrix_federation_agent.py |tests/http/federation/test_matrix_federation_agent.py
|tests/http/federation/test_srv_resolver.py |tests/http/federation/test_srv_resolver.py

View file

@ -12,14 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
import synapse.api.errors import synapse.api.errors
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.rest.client import directory, login, room from synapse.rest.client import directory, login, room
from synapse.types import RoomAlias, create_requester from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias, create_requester
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
@ -28,13 +32,15 @@ from tests.test_utils import make_awaitable
class DirectoryTestCase(unittest.HomeserverTestCase): class DirectoryTestCase(unittest.HomeserverTestCase):
"""Tests the directory service.""" """Tests the directory service."""
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock() self.mock_federation = Mock()
self.mock_registry = Mock() self.mock_registry = Mock()
self.query_handlers = {} self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
def register_query_handler(query_type, handler): def register_query_handler(
query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
) -> None:
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
self.mock_registry.register_query_handler = register_query_handler self.mock_registry.register_query_handler = register_query_handler
@ -54,7 +60,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
return hs return hs
def test_get_local_association(self): def test_get_local_association(self) -> None:
self.get_success( self.get_success(
self.store.create_room_alias_association( self.store.create_room_alias_association(
self.my_room, "!8765qwer:test", ["test"] self.my_room, "!8765qwer:test", ["test"]
@ -65,7 +71,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result) self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
def test_get_remote_association(self): def test_get_remote_association(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable( self.mock_federation.make_query.return_value = make_awaitable(
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]} {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
) )
@ -83,7 +89,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
ignore_backoff=True, ignore_backoff=True,
) )
def test_incoming_fed_query(self): def test_incoming_fed_query(self) -> None:
self.get_success( self.get_success(
self.store.create_room_alias_association( self.store.create_room_alias_association(
self.your_room, "!8765asdf:test", ["test"] self.your_room, "!8765asdf:test", ["test"]
@ -105,7 +111,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
directory.register_servlets, directory.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_directory_handler() self.handler = hs.get_directory_handler()
# Create user # Create user
@ -125,7 +131,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass") self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
def test_create_alias_joined_room(self): def test_create_alias_joined_room(self) -> None:
"""A user can create an alias for a room they're in.""" """A user can create an alias for a room they're in."""
self.get_success( self.get_success(
self.handler.create_association( self.handler.create_association(
@ -135,7 +141,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
) )
) )
def test_create_alias_other_room(self): def test_create_alias_other_room(self) -> None:
"""A user cannot create an alias for a room they're NOT in.""" """A user cannot create an alias for a room they're NOT in."""
other_room_id = self.helper.create_room_as( other_room_id = self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok self.admin_user, tok=self.admin_user_tok
@ -150,7 +156,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError, synapse.api.errors.SynapseError,
) )
def test_create_alias_admin(self): def test_create_alias_admin(self) -> None:
"""An admin can create an alias for a room they're NOT in.""" """An admin can create an alias for a room they're NOT in."""
other_room_id = self.helper.create_room_as( other_room_id = self.helper.create_room_as(
self.test_user, tok=self.test_user_tok self.test_user, tok=self.test_user_tok
@ -173,7 +179,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
directory.register_servlets, directory.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler() self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
@ -195,7 +201,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass") self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
def _create_alias(self, user): def _create_alias(self, user) -> None:
# Create a new alias to this room. # Create a new alias to this room.
self.get_success( self.get_success(
self.store.create_room_alias_association( self.store.create_room_alias_association(
@ -203,7 +209,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
) )
) )
def test_delete_alias_not_allowed(self): def test_delete_alias_not_allowed(self) -> None:
"""A user that doesn't meet the expected guidelines cannot delete an alias.""" """A user that doesn't meet the expected guidelines cannot delete an alias."""
self._create_alias(self.admin_user) self._create_alias(self.admin_user)
self.get_failure( self.get_failure(
@ -213,7 +219,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.AuthError, synapse.api.errors.AuthError,
) )
def test_delete_alias_creator(self): def test_delete_alias_creator(self) -> None:
"""An alias creator can delete their own alias.""" """An alias creator can delete their own alias."""
# Create an alias from a different user. # Create an alias from a different user.
self._create_alias(self.test_user) self._create_alias(self.test_user)
@ -232,7 +238,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError, synapse.api.errors.SynapseError,
) )
def test_delete_alias_admin(self): def test_delete_alias_admin(self) -> None:
"""A server admin can delete an alias created by another user.""" """A server admin can delete an alias created by another user."""
# Create an alias from a different user. # Create an alias from a different user.
self._create_alias(self.test_user) self._create_alias(self.test_user)
@ -251,7 +257,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError, synapse.api.errors.SynapseError,
) )
def test_delete_alias_sufficient_power(self): def test_delete_alias_sufficient_power(self) -> None:
"""A user with a sufficient power level should be able to delete an alias.""" """A user with a sufficient power level should be able to delete an alias."""
self._create_alias(self.admin_user) self._create_alias(self.admin_user)
@ -288,7 +294,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
directory.register_servlets, directory.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler() self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
@ -317,7 +323,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
) )
return room_alias return room_alias
def _set_canonical_alias(self, content): def _set_canonical_alias(self, content) -> None:
"""Configure the canonical alias state on the room.""" """Configure the canonical alias state on the room."""
self.helper.send_state( self.helper.send_state(
self.room_id, self.room_id,
@ -334,7 +340,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
) )
) )
def test_remove_alias(self): def test_remove_alias(self) -> None:
"""Removing an alias that is the canonical alias should remove it there too.""" """Removing an alias that is the canonical alias should remove it there too."""
# Set this new alias as the canonical alias for this room # Set this new alias as the canonical alias for this room
self._set_canonical_alias( self._set_canonical_alias(
@ -356,7 +362,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
self.assertNotIn("alias", data["content"]) self.assertNotIn("alias", data["content"])
self.assertNotIn("alt_aliases", data["content"]) self.assertNotIn("alt_aliases", data["content"])
def test_remove_other_alias(self): def test_remove_other_alias(self) -> None:
"""Removing an alias listed as in alt_aliases should remove it there too.""" """Removing an alias listed as in alt_aliases should remove it there too."""
# Create a second alias. # Create a second alias.
other_test_alias = "#test2:test" other_test_alias = "#test2:test"
@ -393,7 +399,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets] servlets = [directory.register_servlets, room.register_servlets]
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
# Add custom alias creation rules to the config. # Add custom alias creation rules to the config.
@ -403,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
return config return config
def test_denied(self): def test_denied(self) -> None:
room_id = self.helper.create_room_as(self.user_id) room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request( channel = self.make_request(
@ -413,7 +419,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
) )
self.assertEqual(403, channel.code, channel.result) self.assertEqual(403, channel.code, channel.result)
def test_allowed(self): def test_allowed(self) -> None:
room_id = self.helper.create_room_as(self.user_id) room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request( channel = self.make_request(
@ -423,7 +429,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
) )
self.assertEqual(200, channel.code, channel.result) self.assertEqual(200, channel.code, channel.result)
def test_denied_during_creation(self): def test_denied_during_creation(self) -> None:
"""A room alias that is not allowed should be rejected during creation.""" """A room alias that is not allowed should be rejected during creation."""
# Invalid room alias. # Invalid room alias.
self.helper.create_room_as( self.helper.create_room_as(
@ -432,7 +438,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
extra_content={"room_alias_name": "foo"}, extra_content={"room_alias_name": "foo"},
) )
def test_allowed_during_creation(self): def test_allowed_during_creation(self) -> None:
"""A valid room alias should be allowed during creation.""" """A valid room alias should be allowed during creation."""
room_id = self.helper.create_room_as( room_id = self.helper.create_room_as(
self.user_id, self.user_id,
@ -459,7 +465,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
data = {"room_alias_name": "unofficial_test"} data = {"room_alias_name": "unofficial_test"}
allowed_localpart = "allowed" allowed_localpart = "allowed"
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
# Add custom room list publication rules to the config. # Add custom room list publication rules to the config.
@ -474,7 +480,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return config return config
def prepare(self, reactor, clock, hs): def prepare(
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
) -> HomeServer:
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass") self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
self.allowed_access_token = self.login(self.allowed_localpart, "pass") self.allowed_access_token = self.login(self.allowed_localpart, "pass")
@ -483,7 +491,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return hs return hs
def test_denied_without_publication_permission(self): def test_denied_without_publication_permission(self) -> None:
""" """
Try to create a room, register an alias for it, and publish it, Try to create a room, register an alias for it, and publish it,
as a user without permission to publish rooms. as a user without permission to publish rooms.
@ -497,7 +505,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=403, expect_code=403,
) )
def test_allowed_when_creating_private_room(self): def test_allowed_when_creating_private_room(self) -> None:
""" """
Try to create a room, register an alias for it, and NOT publish it, Try to create a room, register an alias for it, and NOT publish it,
as a user without permission to publish rooms. as a user without permission to publish rooms.
@ -511,7 +519,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=200, expect_code=200,
) )
def test_allowed_with_publication_permission(self): def test_allowed_with_publication_permission(self) -> None:
""" """
Try to create a room, register an alias for it, and publish it, Try to create a room, register an alias for it, and publish it,
as a user WITH permission to publish rooms. as a user WITH permission to publish rooms.
@ -525,7 +533,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=200, expect_code=200,
) )
def test_denied_publication_with_invalid_alias(self): def test_denied_publication_with_invalid_alias(self) -> None:
""" """
Try to create a room, register an alias for it, and publish it, Try to create a room, register an alias for it, and publish it,
as a user WITH permission to publish rooms. as a user WITH permission to publish rooms.
@ -538,7 +546,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=403, expect_code=403,
) )
def test_can_create_as_private_room_after_rejection(self): def test_can_create_as_private_room_after_rejection(self) -> None:
""" """
After failing to publish a room with an alias as a user without publish permission, After failing to publish a room with an alias as a user without publish permission,
retry as the same user, but without publishing the room. retry as the same user, but without publishing the room.
@ -549,7 +557,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
self.test_denied_without_publication_permission() self.test_denied_without_publication_permission()
self.test_allowed_when_creating_private_room() self.test_allowed_when_creating_private_room()
def test_can_create_with_permission_after_rejection(self): def test_can_create_with_permission_after_rejection(self) -> None:
""" """
After failing to publish a room with an alias as a user without publish permission, After failing to publish a room with an alias as a user without publish permission,
retry as someone with permission, using the same alias. retry as someone with permission, using the same alias.
@ -566,7 +574,9 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets] servlets = [directory.register_servlets, room.register_servlets]
def prepare(self, reactor, clock, hs): def prepare(
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
) -> HomeServer:
room_id = self.helper.create_room_as(self.user_id) room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request( channel = self.make_request(
@ -579,7 +589,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
return hs return hs
def test_disabling_room_list(self): def test_disabling_room_list(self) -> None:
self.room_list_handler.enable_room_list_search = True self.room_list_handler.enable_room_list_search = True
self.directory_handler.enable_room_list_search = True self.directory_handler.enable_room_list_search = True

View file

@ -20,33 +20,37 @@ from parameterized import parameterized
from signedjson import key as key, sign as sign from signedjson import key as key, sign as sign
from twisted.internet import defer from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=mock.Mock()) return self.setup_test_homeserver(federation_client=mock.Mock())
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_keys_handler() self.handler = hs.get_e2e_keys_handler()
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
def test_query_local_devices_no_devices(self): def test_query_local_devices_no_devices(self) -> None:
"""If the user has no devices, we expect an empty list.""" """If the user has no devices, we expect an empty list."""
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
res = self.get_success(self.handler.query_local_devices({local_user: None})) res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}}) self.assertDictEqual(res, {local_user: {}})
def test_reupload_one_time_keys(self): def test_reupload_one_time_keys(self) -> None:
"""we should be able to re-upload the same keys""" """we should be able to re-upload the same keys"""
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
device_id = "xyz" device_id = "xyz"
keys = { keys: JsonDict = {
"alg1:k1": "key1", "alg1:k1": "key1",
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}}, "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"alg2:k3": {"key": "key3"}, "alg2:k3": {"key": "key3"},
@ -74,7 +78,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}} res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
) )
def test_change_one_time_keys(self): def test_change_one_time_keys(self) -> None:
"""attempts to change one-time-keys should be rejected""" """attempts to change one-time-keys should be rejected"""
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
@ -134,7 +138,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
SynapseError, SynapseError,
) )
def test_claim_one_time_key(self): def test_claim_one_time_key(self) -> None:
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
device_id = "xyz" device_id = "xyz"
keys = {"alg1:k1": "key1"} keys = {"alg1:k1": "key1"}
@ -161,7 +165,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}, },
) )
def test_fallback_key(self): def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
device_id = "xyz" device_id = "xyz"
fallback_key = {"alg1:k1": "fallback_key1"} fallback_key = {"alg1:k1": "fallback_key1"}
@ -294,7 +298,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
) )
def test_replace_master_key(self): def test_replace_master_key(self) -> None:
"""uploading a new signing key should make the old signing key unavailable""" """uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
keys1 = { keys1 = {
@ -328,7 +332,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
) )
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
def test_reupload_signatures(self): def test_reupload_signatures(self) -> None:
"""re-uploading a signature should not fail""" """re-uploading a signature should not fail"""
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
keys1 = { keys1 = {
@ -433,7 +437,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1) self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2) self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
def test_self_signing_key_doesnt_show_up_as_device(self): def test_self_signing_key_doesnt_show_up_as_device(self) -> None:
"""signing keys should be hidden when fetching a user's devices""" """signing keys should be hidden when fetching a user's devices"""
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
keys1 = { keys1 = {
@ -462,7 +466,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res = self.get_success(self.handler.query_local_devices({local_user: None})) res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}}) self.assertDictEqual(res, {local_user: {}})
def test_upload_signatures(self): def test_upload_signatures(self) -> None:
"""should check signatures that are uploaded""" """should check signatures that are uploaded"""
# set up a user with cross-signing keys and a device. This user will # set up a user with cross-signing keys and a device. This user will
# try uploading signatures # try uploading signatures
@ -686,7 +690,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
) )
def test_query_devices_remote_no_sync(self): def test_query_devices_remote_no_sync(self) -> None:
"""Tests that querying keys for a remote user that we don't share a room """Tests that querying keys for a remote user that we don't share a room
with returns the cross signing keys correctly. with returns the cross signing keys correctly.
""" """
@ -759,7 +763,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}, },
) )
def test_query_devices_remote_sync(self): def test_query_devices_remote_sync(self) -> None:
"""Tests that querying keys for a remote user that we share a room with, """Tests that querying keys for a remote user that we share a room with,
but haven't yet fetched the keys for, returns the cross signing keys but haven't yet fetched the keys for, returns the cross signing keys
correctly. correctly.
@ -845,7 +849,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
(["device_1", "device_2"],), (["device_1", "device_2"],),
] ]
) )
def test_query_all_devices_caches_result(self, device_ids: Iterable[str]): def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None:
"""Test that requests for all of a remote user's devices are cached. """Test that requests for all of a remote user's devices are cached.
We do this by asserting that only one call over federation was made, and that We do this by asserting that only one call over federation was made, and that
@ -853,7 +857,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
""" """
local_user_id = "@test:test" local_user_id = "@test:test"
remote_user_id = "@test:other" remote_user_id = "@test:other"
request_body = {"device_keys": {remote_user_id: []}} request_body: JsonDict = {"device_keys": {remote_user_id: []}}
response_devices = [ response_devices = [
{ {

View file

@ -13,14 +13,18 @@
# limitations under the License. # limitations under the License.
import json import json
import os import os
from typing import Any, Dict
from unittest.mock import ANY, Mock, patch from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
import pymacaroons import pymacaroons
from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.sso import MappingException from synapse.handlers.sso import MappingException
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import UserID from synapse.types import JsonDict, UserID
from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon from synapse.util.macaroons import get_value_from_macaroon
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider):
} }
async def get_json(url): async def get_json(url: str) -> JsonDict:
# Mock get_json calls to handle jwks & oidc discovery endpoints # Mock get_json calls to handle jwks & oidc discovery endpoints
if url == WELL_KNOWN: if url == WELL_KNOWN:
# Minimal discovery document, as defined in OpenID.Discovery # Minimal discovery document, as defined in OpenID.Discovery
@ -116,6 +120,8 @@ async def get_json(url):
elif url == JWKS_URI: elif url == JWKS_URI:
return {"keys": []} return {"keys": []}
return {}
def _key_file_path() -> str: def _key_file_path() -> str:
"""path to a file containing the private half of a test key""" """path to a file containing the private half of a test key"""
@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
if not HAS_OIDC: if not HAS_OIDC:
skip = "requires OIDC" skip = "requires OIDC"
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
config["public_baseurl"] = BASE_URL config["public_baseurl"] = BASE_URL
return config return config
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock(spec=["get_json"]) self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = b"Synapse Test" self.http_client.user_agent = b"Synapse Test"
@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
sso_handler = hs.get_sso_handler() sso_handler = hs.get_sso_handler()
# Mock the render error method. # Mock the render error method.
self.render_error = Mock(return_value=None) self.render_error = Mock(return_value=None)
sso_handler.render_error = self.render_error sso_handler.render_error = self.render_error # type: ignore[assignment]
# Reduce the number of attempts when generating MXIDs. # Reduce the number of attempts when generating MXIDs.
sso_handler._MAP_USERNAME_RETRIES = 3 sso_handler._MAP_USERNAME_RETRIES = 3
@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
return args return args
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_config(self): def test_config(self) -> None:
"""Basic config correctly sets up the callback URL and client auth correctly.""" """Basic config correctly sets up the callback URL and client auth correctly."""
self.assertEqual(self.provider._callback_url, CALLBACK_URL) self.assertEqual(self.provider._callback_url, CALLBACK_URL)
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID) self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET) self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}}) @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
def test_discovery(self): def test_discovery(self) -> None:
"""The handler should discover the endpoints from OIDC discovery document.""" """The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid # This would throw if some metadata were invalid
metadata = self.get_success(self.provider.load_metadata()) metadata = self.get_success(self.provider.load_metadata())
@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.http_client.get_json.assert_not_called() self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_no_discovery(self): def test_no_discovery(self) -> None:
"""When discovery is disabled, it should not try to load from discovery document.""" """When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.provider.load_metadata()) self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called() self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_load_jwks(self): def test_load_jwks(self) -> None:
"""JWKS loading is done once (then cached) if used.""" """JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.provider.load_jwks()) jwks = self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI) self.http_client.get_json.assert_called_once_with(JWKS_URI)
@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_failure(self.provider.load_jwks(force=True), RuntimeError) self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_validate_config(self): def test_validate_config(self) -> None:
"""Provider metadatas are extensively validated.""" """Provider metadatas are extensively validated."""
h = self.provider h = self.provider
@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
force_load_metadata() force_load_metadata()
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}}) @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
def test_skip_verification(self): def test_skip_verification(self) -> None:
"""Provider metadata validation can be disabled by config.""" """Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}): with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw # This should not throw
get_awaitable_result(self.provider.load_metadata()) get_awaitable_result(self.provider.load_metadata())
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_redirect_request(self): def test_redirect_request(self) -> None:
"""The redirect request has the right arguments & generates a valid session cookie.""" """The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["cookies"]) req = Mock(spec=["cookies"])
req.cookies = [] req.cookies = []
@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(redirect, "http://client/redirect") self.assertEqual(redirect, "http://client/redirect")
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_error(self): def test_callback_error(self) -> None:
"""Errors from the provider returned in the callback are displayed.""" """Errors from the provider returned in the callback are displayed."""
request = Mock(args={}) request = Mock(args={})
request.args[b"error"] = [b"invalid_client"] request.args[b"error"] = [b"invalid_client"]
@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("invalid_client", "some description") self.assertRenderedError("invalid_client", "some description")
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback(self): def test_callback(self) -> None:
"""Code callback works and display errors if something went wrong. """Code callback works and display errors if something went wrong.
A lot of scenarios are tested here: A lot of scenarios are tested here:
@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": username, "username": username,
} }
expected_user_id = "@%s:%s" % (username, self.hs.hostname) expected_user_id = "@%s:%s" % (username, self.hs.hostname)
self.provider._exchange_code = simple_async_mock(return_value=token) self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
self.provider._parse_id_token = simple_async_mock(return_value=userinfo) self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("mapping_error") self.assertRenderedError("mapping_error")
# Handle ID token errors # Handle ID token errors
self.provider._parse_id_token = simple_async_mock(raises=Exception()) self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token") self.assertRenderedError("invalid_token")
@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"type": "bearer", "type": "bearer",
"access_token": "access_token", "access_token": "access_token",
} }
self.provider._exchange_code = simple_async_mock(return_value=token) self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with( auth_handler.complete_sso_login.assert_called_once_with(
@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
id_token = { id_token = {
"sid": "abcdefgh", "sid": "abcdefgh",
} }
self.provider._parse_id_token = simple_async_mock(return_value=id_token) self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment]
self.provider._exchange_code = simple_async_mock(return_value=token) self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
auth_handler.complete_sso_login.reset_mock() auth_handler.complete_sso_login.reset_mock()
self.provider._fetch_userinfo.reset_mock() self.provider._fetch_userinfo.reset_mock()
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.render_error.assert_not_called() self.render_error.assert_not_called()
# Handle userinfo fetching error # Handle userinfo fetching error
self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error") self.assertRenderedError("fetch_error")
# Handle code exchange failure # Handle code exchange failure
from synapse.handlers.oidc import OidcError from synapse.handlers.oidc import OidcError
self.provider._exchange_code = simple_async_mock( self.provider._exchange_code = simple_async_mock( # type: ignore[assignment]
raises=OidcError("invalid_request") raises=OidcError("invalid_request")
) )
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request") self.assertRenderedError("invalid_request")
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_session(self): def test_callback_session(self) -> None:
"""The callback verifies the session presence and validity""" """The callback verifies the session presence and validity"""
request = Mock(spec=["args", "getCookie", "cookies"]) request = Mock(spec=["args", "getCookie", "cookies"])
@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config( @override_config(
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}} {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
) )
def test_exchange_code(self): def test_exchange_code(self) -> None:
"""Code exchange behaves correctly and handles various error scenarios.""" """Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"} token = {"type": "bearer"}
token_json = json.dumps(token).encode("utf-8") token_json = json.dumps(token).encode("utf-8")
@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_exchange_code_jwt_key(self): def test_exchange_code_jwt_key(self) -> None:
"""Test that code exchange works with a JWK client secret.""" """Test that code exchange works with a JWK client secret."""
from authlib.jose import jwt from authlib.jose import jwt
@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_exchange_code_no_auth(self): def test_exchange_code_no_auth(self) -> None:
"""Test that code exchange works with no client secret.""" """Test that code exchange works with no client secret."""
token = {"type": "bearer"} token = {"type": "bearer"}
self.http_client.request = simple_async_mock( self.http_client.request = simple_async_mock(
@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_extra_attributes(self): def test_extra_attributes(self) -> None:
""" """
Login while using a mapping provider that implements get_extra_attributes. Login while using a mapping provider that implements get_extra_attributes.
""" """
@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "foo", "username": "foo",
"phone": "1234567", "phone": "1234567",
} }
self.provider._exchange_code = simple_async_mock(return_value=token) self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
self.provider._parse_id_token = simple_async_mock(return_value=userinfo) self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
) )
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_user(self): def test_map_userinfo_to_user(self) -> None:
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
userinfo = { userinfo: dict = {
"sub": "test_user", "sub": "test_user",
"username": "test_user", "username": "test_user",
} }
@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
) )
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}}) @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self): def test_map_userinfo_to_existing_user(self) -> None:
"""Existing users can log in with OpenID Connect when allow_existing_users is True.""" """Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
user = UserID.from_string("@test_user:test") user = UserID.from_string("@test_user:test")
@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
) )
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_invalid_localpart(self): def test_map_userinfo_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected.""" """If the mapping provider generates an invalid localpart it should be rejected."""
self.get_success( self.get_success(
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"}) _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_map_userinfo_to_user_retries(self): def test_map_userinfo_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use.""" """The mapping provider can retry generating an MXID if the MXID is already in use."""
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
) )
@override_config({"oidc_config": DEFAULT_CONFIG}) @override_config({"oidc_config": DEFAULT_CONFIG})
def test_empty_localpart(self): def test_empty_localpart(self) -> None:
"""Attempts to map onto an empty localpart should be rejected.""" """Attempts to map onto an empty localpart should be rejected."""
userinfo = { userinfo = {
"sub": "tester", "sub": "tester",
@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_null_localpart(self): def test_null_localpart(self) -> None:
"""Mapping onto a null localpart via an empty OIDC attribute should be rejected""" """Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
userinfo = { userinfo = {
"sub": "tester", "sub": "tester",
@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_attribute_requirements(self): def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the OIDC userinfo response.""" """The required attributes must be met from the OIDC userinfo response."""
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_attribute_requirements_contains(self): def test_attribute_requirements_contains(self) -> None:
"""Test that auth succeeds if userinfo attribute CONTAINS required value""" """Test that auth succeeds if userinfo attribute CONTAINS required value"""
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_attribute_requirements_mismatch(self): def test_attribute_requirements_mismatch(self) -> None:
""" """
Test that auth fails if attributes exist but don't match, Test that auth fails if attributes exist but don't match,
or are non-string values. or are non-string values.
@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() auth_handler.complete_sso_login = simple_async_mock()
# userinfo with "test": "not_foobar" attribute should fail # userinfo with "test": "not_foobar" attribute should fail
userinfo = { userinfo: dict = {
"sub": "tester", "sub": "tester",
"username": "tester", "username": "tester",
"test": "not_foobar", "test": "not_foobar",
@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo(
handler = hs.get_oidc_handler() handler = hs.get_oidc_handler()
provider = handler._providers["oidc"] provider = handler._providers["oidc"]
provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment]
provider._parse_id_token = simple_async_mock(return_value=userinfo) provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
provider._fetch_userinfo = simple_async_mock(return_value=userinfo) provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
state = "state" state = "state"
session = handler._token_generator.generate_oidc_session_token( session = handler._token_generator.generate_oidc_session_token(

View file

@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
import synapse.types import synapse.types
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.rest import admin from synapse.rest import admin
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import UserID from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
@ -29,13 +32,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
servlets = [admin.register_servlets] servlets = [admin.register_servlets]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock() self.mock_federation = Mock()
self.mock_registry = Mock() self.mock_registry = Mock()
self.query_handlers = {} self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
def register_query_handler(query_type, handler): def register_query_handler(
query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
) -> None:
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
self.mock_registry.register_query_handler = register_query_handler self.mock_registry.register_query_handler = register_query_handler
@ -47,7 +52,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
) )
return hs return hs
def prepare(self, reactor, clock, hs: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.frank = UserID.from_string("@1234abcd:test") self.frank = UserID.from_string("@1234abcd:test")
@ -58,7 +63,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.handler = hs.get_profile_handler() self.handler = hs.get_profile_handler()
def test_get_my_name(self): def test_get_my_name(self) -> None:
self.get_success( self.get_success(
self.store.set_profile_displayname(self.frank.localpart, "Frank") self.store.set_profile_displayname(self.frank.localpart, "Frank")
) )
@ -67,7 +72,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual("Frank", displayname) self.assertEqual("Frank", displayname)
def test_set_my_name(self): def test_set_my_name(self) -> None:
self.get_success( self.get_success(
self.handler.set_displayname( self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr." self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
@ -110,7 +115,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.get_profile_displayname(self.frank.localpart)) self.get_success(self.store.get_profile_displayname(self.frank.localpart))
) )
def test_set_my_name_if_disabled(self): def test_set_my_name_if_disabled(self) -> None:
self.hs.config.registration.enable_set_displayname = False self.hs.config.registration.enable_set_displayname = False
# Setting displayname for the first time is allowed # Setting displayname for the first time is allowed
@ -135,7 +140,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
SynapseError, SynapseError,
) )
def test_set_my_name_noauth(self): def test_set_my_name_noauth(self) -> None:
self.get_failure( self.get_failure(
self.handler.set_displayname( self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr." self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
@ -143,7 +148,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
AuthError, AuthError,
) )
def test_get_other_name(self): def test_get_other_name(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable( self.mock_federation.make_query.return_value = make_awaitable(
{"displayname": "Alice"} {"displayname": "Alice"}
) )
@ -158,7 +163,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
ignore_backoff=True, ignore_backoff=True,
) )
def test_incoming_fed_query(self): def test_incoming_fed_query(self) -> None:
self.get_success(self.store.create_profile("caroline")) self.get_success(self.store.create_profile("caroline"))
self.get_success(self.store.set_profile_displayname("caroline", "Caroline")) self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
@ -174,7 +179,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual({"displayname": "Caroline"}, response) self.assertEqual({"displayname": "Caroline"}, response)
def test_get_my_avatar(self): def test_get_my_avatar(self) -> None:
self.get_success( self.get_success(
self.store.set_profile_avatar_url( self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png" self.frank.localpart, "http://my.server/me.png"
@ -184,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual("http://my.server/me.png", avatar_url) self.assertEqual("http://my.server/me.png", avatar_url)
def test_set_my_avatar(self): def test_set_my_avatar(self) -> None:
self.get_success( self.get_success(
self.handler.set_avatar_url( self.handler.set_avatar_url(
self.frank, self.frank,
@ -225,7 +230,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
) )
def test_set_my_avatar_if_disabled(self): def test_set_my_avatar_if_disabled(self) -> None:
self.hs.config.registration.enable_set_avatar_url = False self.hs.config.registration.enable_set_avatar_url = False
# Setting displayname for the first time is allowed # Setting displayname for the first time is allowed
@ -250,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
SynapseError, SynapseError,
) )
def test_avatar_constraints_no_config(self): def test_avatar_constraints_no_config(self) -> None:
"""Tests that the method to check an avatar against configured constraints skips """Tests that the method to check an avatar against configured constraints skips
all of its check if no constraint is configured. all of its check if no constraint is configured.
""" """
@ -263,7 +268,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertTrue(res) self.assertTrue(res)
@unittest.override_config({"max_avatar_size": 50}) @unittest.override_config({"max_avatar_size": 50})
def test_avatar_constraints_missing(self): def test_avatar_constraints_missing(self) -> None:
"""Tests that an avatar isn't allowed if the file at the given MXC URI couldn't """Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
be found. be found.
""" """
@ -273,7 +278,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertFalse(res) self.assertFalse(res)
@unittest.override_config({"max_avatar_size": 50}) @unittest.override_config({"max_avatar_size": 50})
def test_avatar_constraints_file_size(self): def test_avatar_constraints_file_size(self) -> None:
"""Tests that a file that's above the allowed file size is forbidden but one """Tests that a file that's above the allowed file size is forbidden but one
that's below it is allowed. that's below it is allowed.
""" """
@ -295,7 +300,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertFalse(res) self.assertFalse(res)
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
def test_avatar_constraint_mime_type(self): def test_avatar_constraint_mime_type(self) -> None:
"""Tests that a file with an unauthorised MIME type is forbidden but one with """Tests that a file with an unauthorised MIME type is forbidden but one with
an authorised content type is allowed. an authorised content type is allowed.
""" """

View file

@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional from typing import Any, Dict, Optional
from unittest.mock import Mock from unittest.mock import Mock
import attr import attr
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import RedirectException from synapse.api.errors import RedirectException
from synapse.server import HomeServer
from synapse.util import Clock
from tests.test_utils import simple_async_mock from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config from tests.unittest import HomeserverTestCase, override_config
@ -81,10 +85,10 @@ class TestRedirectMappingProvider(TestMappingProvider):
class SamlHandlerTestCase(HomeserverTestCase): class SamlHandlerTestCase(HomeserverTestCase):
def default_config(self): def default_config(self) -> Dict[str, Any]:
config = super().default_config() config = super().default_config()
config["public_baseurl"] = BASE_URL config["public_baseurl"] = BASE_URL
saml_config = { saml_config: Dict[str, Any] = {
"sp_config": {"metadata": {}}, "sp_config": {"metadata": {}},
# Disable grandfathering. # Disable grandfathering.
"grandfathered_mxid_source_attribute": None, "grandfathered_mxid_source_attribute": None,
@ -98,7 +102,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
return config return config
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver() hs = self.setup_test_homeserver()
self.handler = hs.get_saml_handler() self.handler = hs.get_saml_handler()
@ -114,7 +118,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
elif not has_xmlsec1: elif not has_xmlsec1:
skip = "Requires xmlsec1" skip = "Requires xmlsec1"
def test_map_saml_response_to_user(self): def test_map_saml_response_to_user(self) -> None:
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
# stub out the auth handler # stub out the auth handler
@ -140,7 +144,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
) )
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
def test_map_saml_response_to_existing_user(self): def test_map_saml_response_to_existing_user(self) -> None:
"""Existing users can log in with SAML account.""" """Existing users can log in with SAML account."""
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
self.get_success( self.get_success(
@ -186,7 +190,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None, auth_provider_session_id=None,
) )
def test_map_saml_response_to_invalid_localpart(self): def test_map_saml_response_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected.""" """If the mapping provider generates an invalid localpart it should be rejected."""
# stub out the auth handler # stub out the auth handler
@ -207,7 +211,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
) )
auth_handler.complete_sso_login.assert_not_called() auth_handler.complete_sso_login.assert_not_called()
def test_map_saml_response_to_user_retries(self): def test_map_saml_response_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use.""" """The mapping provider can retry generating an MXID if the MXID is already in use."""
# stub out the auth handler and error renderer # stub out the auth handler and error renderer
@ -271,7 +275,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
} }
} }
) )
def test_map_saml_response_redirect(self): def test_map_saml_response_redirect(self) -> None:
"""Test a mapping provider that raises a RedirectException""" """Test a mapping provider that raises a RedirectException"""
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
@ -292,7 +296,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
}, },
} }
) )
def test_attribute_requirements(self): def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the SAML response.""" """The required attributes must be met from the SAML response."""
# stub out the auth handler # stub out the auth handler