diff --git a/changelog.d/11066.misc b/changelog.d/11066.misc new file mode 100644 index 0000000000..1e337bee54 --- /dev/null +++ b/changelog.d/11066.misc @@ -0,0 +1 @@ +Add type hints to `synapse.events`. diff --git a/mypy.ini b/mypy.ini index 93757cd95d..2cdd552f46 100644 --- a/mypy.ini +++ b/mypy.ini @@ -22,8 +22,11 @@ files = synapse/crypto, synapse/event_auth.py, synapse/events/builder.py, + synapse/events/presence_router.py, + synapse/events/snapshot.py, synapse/events/spamcheck.py, synapse/events/third_party_rules.py, + synapse/events/utils.py, synapse/events/validator.py, synapse/federation, synapse/groups, @@ -96,6 +99,9 @@ files = tests/util/test_itertools.py, tests/util/test_stream_change_cache.py +[mypy-synapse.events.*] +disallow_untyped_defs = True + [mypy-synapse.handlers.*] disallow_untyped_defs = True diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 50f2a4c1f4..4f409f31e1 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -90,13 +90,13 @@ class EventBuilder: ) @property - def state_key(self): + def state_key(self) -> str: if self._state_key is not None: return self._state_key raise AttributeError("state_key") - def is_state(self): + def is_state(self) -> bool: return self._state_key is not None async def build( diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py index 68b8b19024..a58f313e8b 100644 --- a/synapse/events/presence_router.py +++ b/synapse/events/presence_router.py @@ -14,6 +14,7 @@ import logging from typing import ( TYPE_CHECKING, + Any, Awaitable, Callable, Dict, @@ -33,14 +34,13 @@ if TYPE_CHECKING: GET_USERS_FOR_STATES_CALLBACK = Callable[ [Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]] ] -GET_INTERESTED_USERS_CALLBACK = Callable[ - [str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]] -] +# This must either return a set of strings or the constant PresenceRouter.ALL_USERS. +GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]] logger = logging.getLogger(__name__) -def load_legacy_presence_router(hs: "HomeServer"): +def load_legacy_presence_router(hs: "HomeServer") -> None: """Wrapper that loads a presence router module configured using the old configuration, and registers the hooks they implement. """ @@ -69,9 +69,10 @@ def load_legacy_presence_router(hs: "HomeServer"): if f is None: return None - def run(*args, **kwargs): - # mypy doesn't do well across function boundaries so we need to tell it - # f is definitely not None. + def run(*args: Any, **kwargs: Any) -> Awaitable: + # Assertion required because mypy can't prove we won't change `f` + # back to `None`. See + # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions assert f is not None return maybe_awaitable(f(*args, **kwargs)) @@ -104,7 +105,7 @@ class PresenceRouter: self, get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None, get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None, - ): + ) -> None: # PresenceRouter modules are required to implement both of these methods # or neither of them as they are assumed to act in a complementary manner paired_methods = [get_users_for_states, get_interested_users] @@ -142,7 +143,7 @@ class PresenceRouter: # Don't include any extra destinations for presence updates return {} - users_for_states = {} + users_for_states: Dict[str, Set[UserPresenceState]] = {} # run all the callbacks for get_users_for_states and combine the results for callback in self._get_users_for_states_callbacks: try: @@ -171,7 +172,7 @@ class PresenceRouter: return users_for_states - async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]: + async def get_interested_users(self, user_id: str) -> Union[Set[str], str]: """ Retrieve a list of users that `user_id` is interested in receiving the presence of. This will be in addition to those they share a room with. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 5ba01eeef9..d7527008c4 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -11,17 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import attr from frozendict import frozendict +from twisted.internet.defer import Deferred + from synapse.appservice import ApplicationService from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background -from synapse.types import StateMap +from synapse.types import JsonDict, StateMap if TYPE_CHECKING: + from synapse.storage import Storage from synapse.storage.databases.main import DataStore @@ -112,13 +115,13 @@ class EventContext: @staticmethod def with_state( - state_group, - state_group_before_event, - current_state_ids, - prev_state_ids, - prev_group=None, - delta_ids=None, - ): + state_group: Optional[int], + state_group_before_event: Optional[int], + current_state_ids: Optional[StateMap[str]], + prev_state_ids: Optional[StateMap[str]], + prev_group: Optional[int] = None, + delta_ids: Optional[StateMap[str]] = None, + ) -> "EventContext": return EventContext( current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, @@ -129,22 +132,22 @@ class EventContext: ) @staticmethod - def for_outlier(): + def for_outlier() -> "EventContext": """Return an EventContext instance suitable for persisting an outlier event""" return EventContext( current_state_ids={}, prev_state_ids={}, ) - async def serialize(self, event: EventBase, store: "DataStore") -> dict: + async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict: """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` Args: - event (FrozenEvent): The event that this context relates to + event: The event that this context relates to Returns: - dict + The serialized event. """ # We don't serialize the full state dicts, instead they get pulled out @@ -170,17 +173,16 @@ class EventContext: } @staticmethod - def deserialize(storage, input): + def deserialize(storage: "Storage", input: JsonDict) -> "EventContext": """Converts a dict that was produced by `serialize` back into a EventContext. Args: - storage (Storage): Used to convert AS ID to AS object and fetch - state. - input (dict): A dict produced by `serialize` + storage: Used to convert AS ID to AS object and fetch state. + input: A dict produced by `serialize` Returns: - EventContext + The event context. """ context = _AsyncEventContextImpl( # We use the state_group and prev_state_id stuff to pull the @@ -241,22 +243,25 @@ class EventContext: await self._ensure_fetched() return self._current_state_ids - async def get_prev_state_ids(self): + async def get_prev_state_ids(self) -> StateMap[str]: """ Gets the room state map, excluding this event. For a non-state event, this will be the same as get_current_state_ids(). Returns: - dict[(str, str), str]|None: Returns None if state_group - is None, which happens when the associated event is an outlier. - Maps a (type, state_key) to the event ID of the state event matching - this tuple. + Returns {} if state_group is None, which happens when the associated + event is an outlier. + + Maps a (type, state_key) to the event ID of the state event matching + this tuple. """ await self._ensure_fetched() + # There *should* be previous state IDs now. + assert self._prev_state_ids is not None return self._prev_state_ids - def get_cached_current_state_ids(self): + def get_cached_current_state_ids(self) -> Optional[StateMap[str]]: """Gets the current state IDs if we have them already cached. It is an error to access this for a rejected event, since rejected state should @@ -264,16 +269,17 @@ class EventContext: ``rejected`` is set. Returns: - dict[(str, str), str]|None: Returns None if we haven't cached the - state or if state_group is None, which happens when the associated - event is an outlier. + Returns None if we haven't cached the state or if state_group is None + (which happens when the associated event is an outlier). + + Otherwise, returns the the current state IDs. """ if self.rejected: raise RuntimeError("Attempt to access state_ids of rejected event") return self._current_state_ids - async def _ensure_fetched(self): + async def _ensure_fetched(self) -> None: return None @@ -285,46 +291,46 @@ class _AsyncEventContextImpl(EventContext): Attributes: - _storage (Storage) + _storage - _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have - been calculated. None if we haven't started calculating yet + _fetching_state_deferred: Resolves when *_state_ids have been calculated. + None if we haven't started calculating yet - _event_type (str): The type of the event the context is associated with. + _event_type: The type of the event the context is associated with. - _event_state_key (str): The state_key of the event the context is - associated with. + _event_state_key: The state_key of the event the context is associated with. - _prev_state_id (str|None): If the event associated with the context is - a state event, then `_prev_state_id` is the event_id of the state - that was replaced. + _prev_state_id: If the event associated with the context is a state event, + then `_prev_state_id` is the event_id of the state that was replaced. """ # This needs to have a default as we're inheriting - _storage = attr.ib(default=None) - _prev_state_id = attr.ib(default=None) - _event_type = attr.ib(default=None) - _event_state_key = attr.ib(default=None) - _fetching_state_deferred = attr.ib(default=None) + _storage: "Storage" = attr.ib(default=None) + _prev_state_id: Optional[str] = attr.ib(default=None) + _event_type: str = attr.ib(default=None) + _event_state_key: Optional[str] = attr.ib(default=None) + _fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None) - async def _ensure_fetched(self): + async def _ensure_fetched(self) -> None: if not self._fetching_state_deferred: self._fetching_state_deferred = run_in_background(self._fill_out_state) - return await make_deferred_yieldable(self._fetching_state_deferred) + await make_deferred_yieldable(self._fetching_state_deferred) - async def _fill_out_state(self): + async def _fill_out_state(self) -> None: """Called to populate the _current_state_ids and _prev_state_ids attributes by loading from the database. """ if self.state_group is None: return - self._current_state_ids = await self._storage.state.get_state_ids_for_group( + current_state_ids = await self._storage.state.get_state_ids_for_group( self.state_group ) + # Set this separately so mypy knows current_state_ids is not None. + self._current_state_ids = current_state_ids if self._event_state_key is not None: - self._prev_state_ids = dict(self._current_state_ids) + self._prev_state_ids = dict(current_state_ids) key = (self._event_type, self._event_state_key) if self._prev_state_id: @@ -332,10 +338,12 @@ class _AsyncEventContextImpl(EventContext): else: self._prev_state_ids.pop(key, None) else: - self._prev_state_ids = self._current_state_ids + self._prev_state_ids = current_state_ids -def _encode_state_dict(state_dict): +def _encode_state_dict( + state_dict: Optional[StateMap[str]], +) -> Optional[List[Tuple[str, str, str]]]: """Since dicts of (type, state_key) -> event_id cannot be serialized in JSON we need to convert them to a form that can. """ @@ -345,7 +353,9 @@ def _encode_state_dict(state_dict): return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()] -def _decode_state_dict(input): +def _decode_state_dict( + input: Optional[List[Tuple[str, str, str]]] +) -> Optional[StateMap[str]]: """Decodes a state dict encoded using `_encode_state_dict` above""" if input is None: return None diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index ae4c8ab257..3134beb8d3 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -77,7 +77,7 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[ ] -def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): +def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None: """Wrapper that loads spam checkers configured using the old configuration, and registers the spam checker hooks they implement. """ @@ -129,9 +129,9 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): request_info: Collection[Tuple[str, str]], auth_provider_id: Optional[str], ) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]: - # We've already made sure f is not None above, but mypy doesn't - # do well across function boundaries so we need to tell it f is - # definitely not None. + # Assertion required because mypy can't prove we won't + # change `f` back to `None`. See + # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions assert f is not None return f( @@ -146,9 +146,10 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): "Bad signature for callback check_registration_for_spam", ) - def run(*args, **kwargs): - # mypy doesn't do well across function boundaries so we need to tell it - # wrapped_func is definitely not None. + def run(*args: Any, **kwargs: Any) -> Awaitable: + # Assertion required because mypy can't prove we won't change `f` + # back to `None`. See + # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions assert wrapped_func is not None return maybe_awaitable(wrapped_func(*args, **kwargs)) @@ -165,7 +166,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): class SpamChecker: - def __init__(self): + def __init__(self) -> None: self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = [] self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = [] self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = [] @@ -209,7 +210,7 @@ class SpamChecker: CHECK_REGISTRATION_FOR_SPAM_CALLBACK ] = None, check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None, - ): + ) -> None: """Register callbacks from module for each hook.""" if check_event_for_spam is not None: self._check_event_for_spam_callbacks.append(check_event_for_spam) @@ -275,7 +276,9 @@ class SpamChecker: return False - async def user_may_join_room(self, user_id: str, room_id: str, is_invited: bool): + async def user_may_join_room( + self, user_id: str, room_id: str, is_invited: bool + ) -> bool: """Checks if a given users is allowed to join a room. Not called when a user creates a room. @@ -285,7 +288,7 @@ class SpamChecker: is_invited: Whether the user is invited into the room Returns: - bool: Whether the user may join the room + Whether the user may join the room """ for callback in self._user_may_join_room_callbacks: if await callback(user_id, room_id, is_invited) is False: diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 976d9fa446..2a6dabdab6 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.events import EventBase @@ -38,7 +38,7 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[ ] -def load_legacy_third_party_event_rules(hs: "HomeServer"): +def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: """Wrapper that loads a third party event rules module configured using the old configuration, and registers the hooks they implement. """ @@ -77,9 +77,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"): event: EventBase, state_events: StateMap[EventBase], ) -> Tuple[bool, Optional[dict]]: - # We've already made sure f is not None above, but mypy doesn't do well - # across function boundaries so we need to tell it f is definitely not - # None. + # Assertion required because mypy can't prove we won't change + # `f` back to `None`. See + # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions assert f is not None res = await f(event, state_events) @@ -98,9 +98,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"): async def wrap_on_create_room( requester: Requester, config: dict, is_requester_admin: bool ) -> None: - # We've already made sure f is not None above, but mypy doesn't do well - # across function boundaries so we need to tell it f is definitely not - # None. + # Assertion required because mypy can't prove we won't change + # `f` back to `None`. See + # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions assert f is not None res = await f(requester, config, is_requester_admin) @@ -112,9 +112,10 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"): return wrap_on_create_room - def run(*args, **kwargs): - # mypy doesn't do well across function boundaries so we need to tell it - # f is definitely not None. + def run(*args: Any, **kwargs: Any) -> Awaitable: + # Assertion required because mypy can't prove we won't change `f` + # back to `None`. See + # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions assert f is not None return maybe_awaitable(f(*args, **kwargs)) @@ -162,7 +163,7 @@ class ThirdPartyEventRules: check_visibility_can_be_modified: Optional[ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = None, - ): + ) -> None: """Register callbacks from modules for each hook.""" if check_event_allowed is not None: self._check_event_allowed_callbacks.append(check_event_allowed) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 520edbbf61..23bd24d963 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -13,18 +13,32 @@ # limitations under the License. import collections.abc import re -from typing import Any, Mapping, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Union, +) from frozendict import frozendict from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion +from synapse.types import JsonDict from synapse.util.async_helpers import yieldable_gather_results from synapse.util.frozenutils import unfreeze from . import EventBase +if TYPE_CHECKING: + from synapse.server import HomeServer + # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' # (? EventBase: return pruned_event -def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: +def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict: """Redacts the event_dict in the same way as `prune_event`, except it operates on dicts rather than event objects @@ -97,7 +111,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: new_content = {} - def add_fields(*fields): + def add_fields(*fields: str) -> None: for field in fields: if field in event_dict["content"]: new_content[field] = event_dict["content"][field] @@ -151,7 +165,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: allowed_fields["content"] = new_content - unsigned = {} + unsigned: JsonDict = {} allowed_fields["unsigned"] = unsigned event_unsigned = event_dict.get("unsigned", {}) @@ -164,16 +178,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: return allowed_fields -def _copy_field(src, dst, field): +def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None: """Copy the field in 'src' to 'dst'. For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"] then dst={"foo":{"bar":5}}. Args: - src(dict): The dict to read from. - dst(dict): The dict to modify. - field(list): List of keys to drill down to in 'src'. + src: The dict to read from. + dst: The dict to modify. + field: List of keys to drill down to in 'src'. """ if len(field) == 0: # this should be impossible return @@ -205,7 +219,7 @@ def _copy_field(src, dst, field): sub_out_dict[key_to_move] = sub_dict[key_to_move] -def only_fields(dictionary, fields): +def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict: """Return a new dict with only the fields in 'dictionary' which are present in 'fields'. @@ -215,11 +229,11 @@ def only_fields(dictionary, fields): A literal '.' character in a field name may be escaped using a '\'. Args: - dictionary(dict): The dictionary to read from. - fields(list): A list of fields to copy over. Only shallow refs are + dictionary: The dictionary to read from. + fields: A list of fields to copy over. Only shallow refs are taken. Returns: - dict: A new dictionary with only the given fields. If fields was empty, + A new dictionary with only the given fields. If fields was empty, the same dictionary is returned. """ if len(fields) == 0: @@ -235,17 +249,17 @@ def only_fields(dictionary, fields): [f.replace(r"\.", r".") for f in field_array] for field_array in split_fields ] - output = {} + output: JsonDict = {} for field_array in split_fields: _copy_field(dictionary, output, field_array) return output -def format_event_raw(d): +def format_event_raw(d: JsonDict) -> JsonDict: return d -def format_event_for_client_v1(d): +def format_event_for_client_v1(d: JsonDict) -> JsonDict: d = format_event_for_client_v2(d) sender = d.get("sender") @@ -267,7 +281,7 @@ def format_event_for_client_v1(d): return d -def format_event_for_client_v2(d): +def format_event_for_client_v2(d: JsonDict) -> JsonDict: drop_keys = ( "auth_events", "prev_events", @@ -282,37 +296,37 @@ def format_event_for_client_v2(d): return d -def format_event_for_client_v2_without_room_id(d): +def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict: d = format_event_for_client_v2(d) d.pop("room_id", None) return d def serialize_event( - e, - time_now_ms, - as_client_event=True, - event_format=format_event_for_client_v1, - token_id=None, - only_event_fields=None, - include_stripped_room_state=False, -): + e: Union[JsonDict, EventBase], + time_now_ms: int, + as_client_event: bool = True, + event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1, + token_id: Optional[str] = None, + only_event_fields: Optional[List[str]] = None, + include_stripped_room_state: bool = False, +) -> JsonDict: """Serialize event for clients Args: - e (EventBase) - time_now_ms (int) - as_client_event (bool) + e + time_now_ms + as_client_event event_format token_id only_event_fields - include_stripped_room_state (bool): Some events can have stripped room state + include_stripped_room_state: Some events can have stripped room state stored in the `unsigned` field. This is required for invite and knock functionality. If this option is False, that state will be removed from the event before it is returned. Otherwise, it will be kept. Returns: - dict + The serialized event dictionary. """ # FIXME(erikj): To handle the case of presence events and the like @@ -369,25 +383,29 @@ class EventClientSerializer: clients. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.experimental_msc1849_support_enabled = ( hs.config.server.experimental_msc1849_support_enabled ) async def serialize_event( - self, event, time_now, bundle_aggregations=True, **kwargs - ): + self, + event: Union[JsonDict, EventBase], + time_now: int, + bundle_aggregations: bool = True, + **kwargs: Any, + ) -> JsonDict: """Serializes a single event. Args: - event (EventBase) - time_now (int): The current time in milliseconds - bundle_aggregations (bool): Whether to bundle in related events + event + time_now: The current time in milliseconds + bundle_aggregations: Whether to bundle in related events **kwargs: Arguments to pass to `serialize_event` Returns: - dict: The serialized event + The serialized event """ # To handle the case of presence events and the like if not isinstance(event, EventBase): @@ -448,25 +466,27 @@ class EventClientSerializer: return serialized_event - def serialize_events(self, events, time_now, **kwargs): + async def serialize_events( + self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any + ) -> List[JsonDict]: """Serializes multiple events. Args: - event (iter[EventBase]) - time_now (int): The current time in milliseconds + event + time_now: The current time in milliseconds **kwargs: Arguments to pass to `serialize_event` Returns: - Deferred[list[dict]]: The list of serialized events + The list of serialized events """ - return yieldable_gather_results( + return await yieldable_gather_results( self.serialize_event, events, time_now=time_now, **kwargs ) def copy_power_levels_contents( old_power_levels: Mapping[str, Union[int, Mapping[str, int]]] -): +) -> Dict[str, Union[int, Dict[str, int]]]: """Copy the content of a power_levels event, unfreezing frozendicts along the way Raises: @@ -475,7 +495,7 @@ def copy_power_levels_contents( if not isinstance(old_power_levels, collections.abc.Mapping): raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,)) - power_levels = {} + power_levels: Dict[str, Union[int, Dict[str, int]]] = {} for k, v in old_power_levels.items(): if isinstance(v, int): @@ -483,7 +503,8 @@ def copy_power_levels_contents( continue if isinstance(v, collections.abc.Mapping): - power_levels[k] = h = {} + h: Dict[str, int] = {} + power_levels[k] = h for k1, v1 in v.items(): # we should only have one level of nesting if not isinstance(v1, int): @@ -498,7 +519,7 @@ def copy_power_levels_contents( return power_levels -def validate_canonicaljson(value: Any): +def validate_canonicaljson(value: Any) -> None: """ Ensure that the JSON object is valid according to the rules of canonical JSON. diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 6eb6544c4c..4d459c17f1 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections.abc -from typing import Union +from typing import Iterable, Union import jsonschema @@ -28,11 +28,11 @@ from synapse.events.utils import ( validate_canonicaljson, ) from synapse.federation.federation_server import server_matches_acl_event -from synapse.types import EventID, RoomID, UserID +from synapse.types import EventID, JsonDict, RoomID, UserID class EventValidator: - def validate_new(self, event: EventBase, config: HomeServerConfig): + def validate_new(self, event: EventBase, config: HomeServerConfig) -> None: """Validates the event has roughly the right format Args: @@ -116,7 +116,7 @@ class EventValidator: errcode=Codes.BAD_JSON, ) - def _validate_retention(self, event: EventBase): + def _validate_retention(self, event: EventBase) -> None: """Checks that an event that defines the retention policy for a room respects the format enforced by the spec. @@ -156,7 +156,7 @@ class EventValidator: errcode=Codes.BAD_JSON, ) - def validate_builder(self, event: Union[EventBase, EventBuilder]): + def validate_builder(self, event: Union[EventBase, EventBuilder]) -> None: """Validates that the builder/event has roughly the right format. Only checks values that we expect a proto event to have, rather than all the fields an event would have @@ -204,14 +204,14 @@ class EventValidator: self._ensure_state_event(event) - def _ensure_strings(self, d, keys): + def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None: for s in keys: if s not in d: raise SynapseError(400, "'%s' not in content" % (s,)) if not isinstance(d[s], str): raise SynapseError(400, "'%s' not a string type" % (s,)) - def _ensure_state_event(self, event): + def _ensure_state_event(self, event: Union[EventBase, EventBuilder]) -> None: if not event.is_state(): raise SynapseError(400, "'%s' must be state events" % (event.type,)) @@ -244,7 +244,9 @@ POWER_LEVELS_SCHEMA = { } -def _create_power_level_validator(): +# This could return something newer than Draft 7, but that's the current "latest" +# validator. +def _create_power_level_validator() -> jsonschema.Draft7Validator: validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA) # by default jsonschema does not consider a frozendict to be an object so diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 7072bca1fc..6f39e9446f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -465,17 +465,35 @@ class RoomCreationHandler: # the room has been created # Calculate the minimum power level needed to clone the room event_power_levels = power_levels.get("events", {}) + if not isinstance(event_power_levels, dict): + event_power_levels = {} state_default = power_levels.get("state_default", 50) + try: + state_default_int = int(state_default) # type: ignore[arg-type] + except (TypeError, ValueError): + state_default_int = 50 ban = power_levels.get("ban", 50) - needed_power_level = max(state_default, ban, max(event_power_levels.values())) + try: + ban = int(ban) # type: ignore[arg-type] + except (TypeError, ValueError): + ban = 50 + needed_power_level = max( + state_default_int, ban, max(event_power_levels.values()) + ) # Get the user's current power level, this matches the logic in get_user_power_level, # but without the entire state map. user_power_levels = power_levels.setdefault("users", {}) + if not isinstance(user_power_levels, dict): + user_power_levels = {} users_default = power_levels.get("users_default", 0) current_power_level = user_power_levels.get(user_id, users_default) + try: + current_power_level_int = int(current_power_level) # type: ignore[arg-type] + except (TypeError, ValueError): + current_power_level_int = 0 # Raise the requester's power level in the new room if necessary - if current_power_level < needed_power_level: + if current_power_level_int < needed_power_level: user_power_levels[user_id] = needed_power_level await self._send_events_for_new_room( diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 0b0711c03c..d695c18be2 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -232,12 +232,12 @@ class RelationPaginationServlet(RestServlet): # Similarly, we don't allow relations to be applied to relations, so we # return the original relations without any aggregations on top of them # here. - events = await self._event_serializer.serialize_events( + serialized_events = await self._event_serializer.serialize_events( events, now, bundle_aggregations=False ) return_value = pagination_chunk.to_dict() - return_value["chunk"] = events + return_value["chunk"] = serialized_events return_value["original_event"] = original_event return 200, return_value @@ -416,10 +416,10 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ) now = self.clock.time_msec() - events = await self._event_serializer.serialize_events(events, now) + serialized_events = await self._event_serializer.serialize_events(events, now) return_value = result.to_dict() - return_value["chunk"] = events + return_value["chunk"] = serialized_events return 200, return_value