Implement knock feature (#6739)

This PR aims to implement the knock feature as proposed in https://github.com/matrix-org/matrix-doc/pull/2403

Signed-off-by: Sorunome mail@sorunome.de
Signed-off-by: Andrew Morgan andrewm@element.io
This commit is contained in:
Sorunome 2021-06-09 20:39:51 +02:00 committed by GitHub
parent 11846dff8c
commit d936371b69
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
29 changed files with 1614 additions and 119 deletions

1
changelog.d/6739.feature Normal file
View file

@ -0,0 +1 @@
Implement "room knocking" as per [MSC2403](https://github.com/matrix-org/matrix-doc/pull/2403). Contributed by Sorunome and anoa.

View file

@ -41,7 +41,7 @@ class Membership:
INVITE = "invite" INVITE = "invite"
JOIN = "join" JOIN = "join"
KNOCK = "knock" KNOCK = "xyz.amorgan.knock"
LEAVE = "leave" LEAVE = "leave"
BAN = "ban" BAN = "ban"
LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN) LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN)
@ -58,7 +58,7 @@ class PresenceState:
class JoinRules: class JoinRules:
PUBLIC = "public" PUBLIC = "public"
KNOCK = "knock" KNOCK = "xyz.amorgan.knock"
INVITE = "invite" INVITE = "invite"
PRIVATE = "private" PRIVATE = "private"
# As defined for MSC3083. # As defined for MSC3083.

View file

@ -449,7 +449,7 @@ class IncompatibleRoomVersionError(SynapseError):
super().__init__( super().__init__(
code=400, code=400,
msg="Your homeserver does not support the features required to " msg="Your homeserver does not support the features required to "
"join this room", "interact with this room",
errcode=Codes.INCOMPATIBLE_ROOM_VERSION, errcode=Codes.INCOMPATIBLE_ROOM_VERSION,
) )

View file

@ -56,7 +56,7 @@ class RoomVersion:
state_res = attr.ib(type=int) # one of the StateResolutionVersions state_res = attr.ib(type=int) # one of the StateResolutionVersions
enforce_key_validity = attr.ib(type=bool) enforce_key_validity = attr.ib(type=bool)
# Before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules # Before MSC2432, m.room.aliases had special auth rules and redaction rules
special_case_aliases_auth = attr.ib(type=bool) special_case_aliases_auth = attr.ib(type=bool)
# Strictly enforce canonicaljson, do not allow: # Strictly enforce canonicaljson, do not allow:
# * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1] # * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1]
@ -70,6 +70,9 @@ class RoomVersion:
msc2176_redaction_rules = attr.ib(type=bool) msc2176_redaction_rules = attr.ib(type=bool)
# MSC3083: Support the 'restricted' join_rule. # MSC3083: Support the 'restricted' join_rule.
msc3083_join_rules = attr.ib(type=bool) msc3083_join_rules = attr.ib(type=bool)
# MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending
# m.room.membership event with membership 'knock'.
msc2403_knocking = attr.ib(type=bool)
class RoomVersions: class RoomVersions:
@ -84,6 +87,7 @@ class RoomVersions:
limit_notifications_power_levels=False, limit_notifications_power_levels=False,
msc2176_redaction_rules=False, msc2176_redaction_rules=False,
msc3083_join_rules=False, msc3083_join_rules=False,
msc2403_knocking=False,
) )
V2 = RoomVersion( V2 = RoomVersion(
"2", "2",
@ -96,6 +100,7 @@ class RoomVersions:
limit_notifications_power_levels=False, limit_notifications_power_levels=False,
msc2176_redaction_rules=False, msc2176_redaction_rules=False,
msc3083_join_rules=False, msc3083_join_rules=False,
msc2403_knocking=False,
) )
V3 = RoomVersion( V3 = RoomVersion(
"3", "3",
@ -108,6 +113,7 @@ class RoomVersions:
limit_notifications_power_levels=False, limit_notifications_power_levels=False,
msc2176_redaction_rules=False, msc2176_redaction_rules=False,
msc3083_join_rules=False, msc3083_join_rules=False,
msc2403_knocking=False,
) )
V4 = RoomVersion( V4 = RoomVersion(
"4", "4",
@ -120,6 +126,7 @@ class RoomVersions:
limit_notifications_power_levels=False, limit_notifications_power_levels=False,
msc2176_redaction_rules=False, msc2176_redaction_rules=False,
msc3083_join_rules=False, msc3083_join_rules=False,
msc2403_knocking=False,
) )
V5 = RoomVersion( V5 = RoomVersion(
"5", "5",
@ -132,6 +139,7 @@ class RoomVersions:
limit_notifications_power_levels=False, limit_notifications_power_levels=False,
msc2176_redaction_rules=False, msc2176_redaction_rules=False,
msc3083_join_rules=False, msc3083_join_rules=False,
msc2403_knocking=False,
) )
V6 = RoomVersion( V6 = RoomVersion(
"6", "6",
@ -144,6 +152,7 @@ class RoomVersions:
limit_notifications_power_levels=True, limit_notifications_power_levels=True,
msc2176_redaction_rules=False, msc2176_redaction_rules=False,
msc3083_join_rules=False, msc3083_join_rules=False,
msc2403_knocking=False,
) )
MSC2176 = RoomVersion( MSC2176 = RoomVersion(
"org.matrix.msc2176", "org.matrix.msc2176",
@ -156,6 +165,7 @@ class RoomVersions:
limit_notifications_power_levels=True, limit_notifications_power_levels=True,
msc2176_redaction_rules=True, msc2176_redaction_rules=True,
msc3083_join_rules=False, msc3083_join_rules=False,
msc2403_knocking=False,
) )
MSC3083 = RoomVersion( MSC3083 = RoomVersion(
"org.matrix.msc3083", "org.matrix.msc3083",
@ -168,6 +178,20 @@ class RoomVersions:
limit_notifications_power_levels=True, limit_notifications_power_levels=True,
msc2176_redaction_rules=False, msc2176_redaction_rules=False,
msc3083_join_rules=True, msc3083_join_rules=True,
msc2403_knocking=False,
)
MSC2403 = RoomVersion(
"xyz.amorgan.knock",
RoomDisposition.UNSTABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
enforce_key_validity=True,
special_case_aliases_auth=False,
strict_canonicaljson=True,
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc2403_knocking=True,
) )
@ -183,4 +207,5 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.MSC2176, RoomVersions.MSC2176,
RoomVersions.MSC3083, RoomVersions.MSC3083,
) )
# Note that we do not include MSC2043 here unless it is enabled in the config.
} # type: Dict[str, RoomVersion] } # type: Dict[str, RoomVersion]

View file

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from prometheus_client import Counter from prometheus_client import Counter
from synapse.api.constants import EventTypes, ThirdPartyEntityKind from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
@ -247,9 +247,14 @@ class ApplicationServiceApi(SimpleHttpClient):
e, e,
time_now, time_now,
as_client_event=True, as_client_event=True,
is_invite=( # If this is an invite or a knock membership event, and we're interested
# in this user, then include any stripped state alongside the event.
include_stripped_room_state=(
e.type == EventTypes.Member e.type == EventTypes.Member
and e.membership == "invite" and (
e.membership == Membership.INVITE
or e.membership == Membership.KNOCK
)
and service.is_interested_in_user(e.state_key) and service.is_interested_in_user(e.state_key)
), ),
) )

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright 2020 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");

View file

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.config._base import Config from synapse.config._base import Config
from synapse.types import JsonDict from synapse.types import JsonDict
@ -29,3 +30,9 @@ class ExperimentalConfig(Config):
# MSC3026 (busy presence state) # MSC3026 (busy presence state)
self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool
# MSC2403 (room knocking)
self.msc2403_enabled = experimental.get("msc2403_enabled", False) # type: bool
if self.msc2403_enabled:
# Enable the MSC2403 unstable room version
KNOWN_ROOM_VERSIONS[RoomVersions.MSC2403.identifier] = RoomVersions.MSC2403

View file

@ -160,6 +160,7 @@ def check(
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()]) logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()])
# 5. If type is m.room.membership
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
_is_membership_change_allowed(room_version_obj, event, auth_events) _is_membership_change_allowed(room_version_obj, event, auth_events)
logger.debug("Allowing! %s", event) logger.debug("Allowing! %s", event)
@ -257,6 +258,11 @@ def _is_membership_change_allowed(
caller_in_room = caller and caller.membership == Membership.JOIN caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE caller_invited = caller and caller.membership == Membership.INVITE
caller_knocked = (
caller
and room_version.msc2403_knocking
and caller.membership == Membership.KNOCK
)
# get info about the target # get info about the target
key = (EventTypes.Member, target_user_id) key = (EventTypes.Member, target_user_id)
@ -283,6 +289,7 @@ def _is_membership_change_allowed(
{ {
"caller_in_room": caller_in_room, "caller_in_room": caller_in_room,
"caller_invited": caller_invited, "caller_invited": caller_invited,
"caller_knocked": caller_knocked,
"target_banned": target_banned, "target_banned": target_banned,
"target_in_room": target_in_room, "target_in_room": target_in_room,
"membership": membership, "membership": membership,
@ -299,9 +306,14 @@ def _is_membership_change_allowed(
raise AuthError(403, "%s is banned from the room" % (target_user_id,)) raise AuthError(403, "%s is banned from the room" % (target_user_id,))
return return
if Membership.JOIN != membership: # Require the user to be in the room for membership changes other than join/knock.
if Membership.JOIN != membership and (
RoomVersion.msc2403_knocking and Membership.KNOCK != membership
):
# If the user has been invited or has knocked, they are allowed to change their
# membership event to leave
if ( if (
caller_invited (caller_invited or caller_knocked)
and Membership.LEAVE == membership and Membership.LEAVE == membership
and target_user_id == event.user_id and target_user_id == event.user_id
): ):
@ -339,7 +351,9 @@ def _is_membership_change_allowed(
and join_rule == JoinRules.MSC3083_RESTRICTED and join_rule == JoinRules.MSC3083_RESTRICTED
): ):
pass pass
elif join_rule == JoinRules.INVITE: elif join_rule == JoinRules.INVITE or (
room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
):
if not caller_in_room and not caller_invited: if not caller_in_room and not caller_invited:
raise AuthError(403, "You are not invited to this room.") raise AuthError(403, "You are not invited to this room.")
else: else:
@ -358,6 +372,17 @@ def _is_membership_change_allowed(
elif Membership.BAN == membership: elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level: if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban") raise AuthError(403, "You don't have permission to ban")
elif room_version.msc2403_knocking and Membership.KNOCK == membership:
if join_rule != JoinRules.KNOCK:
raise AuthError(403, "You don't have permission to knock")
elif target_user_id != event.user_id:
raise AuthError(403, "You cannot knock for other users")
elif target_in_room:
raise AuthError(403, "You cannot knock on a room you are already in")
elif caller_invited:
raise AuthError(403, "You are already invited to this room")
elif target_banned:
raise AuthError(403, "You are banned from this room")
else: else:
raise AuthError(500, "Unknown membership %s" % membership) raise AuthError(500, "Unknown membership %s" % membership)
@ -718,7 +743,7 @@ def auth_types_for_event(event: EventBase) -> Set[Tuple[str, str]]:
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
membership = event.content["membership"] membership = event.content["membership"]
if membership in [Membership.JOIN, Membership.INVITE]: if membership in [Membership.JOIN, Membership.INVITE, Membership.KNOCK]:
auth_types.add((EventTypes.JoinRules, "")) auth_types.add((EventTypes.JoinRules, ""))
auth_types.add((EventTypes.Member, event.state_key)) auth_types.add((EventTypes.Member, event.state_key))

View file

@ -242,6 +242,7 @@ def format_event_for_client_v1(d):
"replaces_state", "replaces_state",
"prev_content", "prev_content",
"invite_room_state", "invite_room_state",
"knock_room_state",
) )
for key in copy_keys: for key in copy_keys:
if key in d["unsigned"]: if key in d["unsigned"]:
@ -278,7 +279,7 @@ def serialize_event(
event_format=format_event_for_client_v1, event_format=format_event_for_client_v1,
token_id=None, token_id=None,
only_event_fields=None, only_event_fields=None,
is_invite=False, include_stripped_room_state=False,
): ):
"""Serialize event for clients """Serialize event for clients
@ -289,8 +290,10 @@ def serialize_event(
event_format event_format
token_id token_id
only_event_fields only_event_fields
is_invite (bool): Whether this is an invite that is being sent to the include_stripped_room_state (bool): Some events can have stripped room state
invitee stored in the `unsigned` field. This is required for invite and knock
functionality. If this option is False, that state will be removed from the
event before it is returned. Otherwise, it will be kept.
Returns: Returns:
dict dict
@ -322,11 +325,13 @@ def serialize_event(
if txn_id is not None: if txn_id is not None:
d["unsigned"]["transaction_id"] = txn_id d["unsigned"]["transaction_id"] = txn_id
# If this is an invite for somebody else, then we don't care about the # invite_room_state and knock_room_state are a list of stripped room state events
# invite_room_state as that's meant solely for the invitee. Other clients # that are meant to provide metadata about a room to an invitee/knocker. They are
# will already have the state since they're in the room. # intended to only be included in specific circumstances, such as down sync, and
if not is_invite: # should not be included in any other case.
if not include_stripped_room_state:
d["unsigned"].pop("invite_room_state", None) d["unsigned"].pop("invite_room_state", None)
d["unsigned"].pop("knock_room_state", None)
if as_client_event: if as_client_event:
d = event_format(d) d = event_format(d)

View file

@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015-2021 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -89,6 +90,7 @@ class FederationClient(FederationBase):
self._clock.looping_call(self._clear_tried_cache, 60 * 1000) self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client() self.transport_layer = hs.get_federation_transport_client()
self._msc2403_enabled = hs.config.experimental.msc2403_enabled
self.hostname = hs.hostname self.hostname = hs.hostname
self.signing_key = hs.signing_key self.signing_key = hs.signing_key
@ -620,6 +622,11 @@ class FederationClient(FederationBase):
no servers successfully handle the request. no servers successfully handle the request.
""" """
valid_memberships = {Membership.JOIN, Membership.LEAVE} valid_memberships = {Membership.JOIN, Membership.LEAVE}
# Allow knocking if the feature is enabled
if self._msc2403_enabled:
valid_memberships.add(Membership.KNOCK)
if membership not in valid_memberships: if membership not in valid_memberships:
raise RuntimeError( raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s" "make_membership_event called with membership='%s', must be one of %s"
@ -638,6 +645,13 @@ class FederationClient(FederationBase):
if not room_version: if not room_version:
raise UnsupportedRoomVersionError() raise UnsupportedRoomVersionError()
if not room_version.msc2403_knocking and membership == Membership.KNOCK:
raise SynapseError(
400,
"This room version does not support knocking",
errcode=Codes.FORBIDDEN,
)
pdu_dict = ret.get("event", None) pdu_dict = ret.get("event", None)
if not isinstance(pdu_dict, dict): if not isinstance(pdu_dict, dict):
raise InvalidResponseError("Bad 'event' field in response") raise InvalidResponseError("Bad 'event' field in response")
@ -946,6 +960,62 @@ class FederationClient(FederationBase):
# content. # content.
return resp[1] return resp[1]
async def send_knock(self, destinations: List[str], pdu: EventBase) -> JsonDict:
"""Attempts to send a knock event to given a list of servers. Iterates
through the list until one attempt succeeds.
Doing so will cause the remote server to add the event to the graph,
and send the event out to the rest of the federation.
Args:
destinations: A list of candidate homeservers which are likely to be
participating in the room.
pdu: The event to be sent.
Returns:
The remote homeserver return some state from the room. The response
dictionary is in the form:
{"knock_state_events": [<state event dict>, ...]}
The list of state events may be empty.
Raises:
SynapseError: If the chosen remote server returns a 3xx/4xx code.
RuntimeError: If no servers were reachable.
"""
async def send_request(destination: str) -> JsonDict:
return await self._do_send_knock(destination, pdu)
return await self._try_destination_list(
"xyz.amorgan.knock/send_knock", destinations, send_request
)
async def _do_send_knock(self, destination: str, pdu: EventBase) -> JsonDict:
"""Send a knock event to a remote homeserver.
Args:
destination: The homeserver to send to.
pdu: The event to send.
Returns:
The remote homeserver can optionally return some state from the room. The response
dictionary is in the form:
{"knock_state_events": [<state event dict>, ...]}
The list of state events may be empty.
"""
time_now = self._clock.time_msec()
return await self.transport_layer.send_knock_v1(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
async def get_public_rooms( async def get_public_rooms(
self, self,
remote_server: str, remote_server: str,

View file

@ -138,6 +138,8 @@ class FederationServer(FederationBase):
hs.config.federation.federation_metrics_domains hs.config.federation.federation_metrics_domains
) )
self._room_prejoin_state_types = hs.config.api.room_prejoin_state
async def on_backfill_request( async def on_backfill_request(
self, origin: str, room_id: str, versions: List[str], limit: int self, origin: str, room_id: str, versions: List[str], limit: int
) -> Tuple[int, Dict[str, Any]]: ) -> Tuple[int, Dict[str, Any]]:
@ -586,6 +588,103 @@ class FederationServer(FederationBase):
await self.handler.on_send_leave_request(origin, pdu) await self.handler.on_send_leave_request(origin, pdu)
return {} return {}
async def on_make_knock_request(
self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
) -> Dict[str, Union[EventBase, str]]:
"""We've received a /make_knock/ request, so we create a partial knock
event for the room and hand that back, along with the room version, to the knocking
homeserver. We do *not* persist or process this event until the other server has
signed it and sent it back.
Args:
origin: The (verified) server name of the requesting server.
room_id: The room to create the knock event in.
user_id: The user to create the knock for.
supported_versions: The room versions supported by the requesting server.
Returns:
The partial knock event.
"""
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
room_version = await self.store.get_room_version(room_id)
# Check that this room version is supported by the remote homeserver
if room_version.identifier not in supported_versions:
logger.warning(
"Room version %s not in %s", room_version.identifier, supported_versions
)
raise IncompatibleRoomVersionError(room_version=room_version.identifier)
# Check that this room supports knocking as defined by its room version
if not room_version.msc2403_knocking:
raise SynapseError(
403,
"This room version does not support knocking",
errcode=Codes.FORBIDDEN,
)
pdu = await self.handler.on_make_knock_request(origin, room_id, user_id)
time_now = self._clock.time_msec()
return {
"event": pdu.get_pdu_json(time_now),
"room_version": room_version.identifier,
}
async def on_send_knock_request(
self,
origin: str,
content: JsonDict,
room_id: str,
) -> Dict[str, List[JsonDict]]:
"""
We have received a knock event for a room. Verify and send the event into the room
on the knocking homeserver's behalf. Then reply with some stripped state from the
room for the knockee.
Args:
origin: The remote homeserver of the knocking user.
content: The content of the request.
room_id: The ID of the room to knock on.
Returns:
The stripped room state.
"""
logger.debug("on_send_knock_request: content: %s", content)
room_version = await self.store.get_room_version(room_id)
# Check that this room supports knocking as defined by its room version
if not room_version.msc2403_knocking:
raise SynapseError(
403,
"This room version does not support knocking",
errcode=Codes.FORBIDDEN,
)
pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, pdu.room_id)
logger.debug("on_send_knock_request: pdu sigs: %s", pdu.signatures)
pdu = await self._check_sigs_and_hash(room_version, pdu)
# Handle the event, and retrieve the EventContext
event_context = await self.handler.on_send_knock_request(origin, pdu)
# Retrieve stripped state events from the room and send them back to the remote
# server. This will allow the remote server's clients to display information
# related to the room while the knock request is pending.
stripped_room_state = (
await self.store.get_stripped_room_state_from_event_context(
event_context, self._room_prejoin_state_types
)
)
return {"knock_state_events": stripped_room_state}
async def on_event_auth( async def on_event_auth(
self, origin: str, room_id: str, event_id: str self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]: ) -> Tuple[int, Dict[str, Any]]:

View file

@ -1,5 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright 2018 New Vector Ltd # Copyright 2020 Sorunome
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -47,6 +47,7 @@ class TransportLayerClient:
def __init__(self, hs): def __init__(self, hs):
self.server_name = hs.hostname self.server_name = hs.hostname
self.client = hs.get_federation_http_client() self.client = hs.get_federation_http_client()
self._msc2403_enabled = hs.config.experimental.msc2403_enabled
@log_function @log_function
def get_room_state_ids(self, destination, room_id, event_id): def get_room_state_ids(self, destination, room_id, event_id):
@ -221,11 +222,27 @@ class TransportLayerClient:
is not in our federation whitelist is not in our federation whitelist
""" """
valid_memberships = {Membership.JOIN, Membership.LEAVE} valid_memberships = {Membership.JOIN, Membership.LEAVE}
# Allow knocking if the feature is enabled
if self._msc2403_enabled:
valid_memberships.add(Membership.KNOCK)
if membership not in valid_memberships: if membership not in valid_memberships:
raise RuntimeError( raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s" "make_membership_event called with membership='%s', must be one of %s"
% (membership, ",".join(valid_memberships)) % (membership, ",".join(valid_memberships))
) )
# Knock currently uses an unstable prefix
if membership == Membership.KNOCK:
# Create a path in the form of /unstable/xyz.amorgan.knock/make_knock/...
path = _create_path(
FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock",
"/make_knock/%s/%s",
room_id,
user_id,
)
else:
path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id) path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id)
ignore_backoff = False ignore_backoff = False
@ -321,6 +338,45 @@ class TransportLayerClient:
return response return response
@log_function
async def send_knock_v1(
self,
destination: str,
room_id: str,
event_id: str,
content: JsonDict,
) -> JsonDict:
"""
Sends a signed knock membership event to a remote server. This is the second
step for knocking after make_knock.
Args:
destination: The remote homeserver.
room_id: The ID of the room to knock on.
event_id: The ID of the knock membership event that we're sending.
content: The knock membership event that we're sending. Note that this is not the
`content` field of the membership event, but the entire signed membership event
itself represented as a JSON dict.
Returns:
The remote homeserver can optionally return some state from the room. The response
dictionary is in the form:
{"knock_state_events": [<state event dict>, ...]}
The list of state events may be empty.
"""
path = _create_path(
FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock",
"/send_knock/%s/%s",
room_id,
event_id,
)
return await self.client.put_json(
destination=destination, path=path, data=content
)
@log_function @log_function
async def send_invite_v1(self, destination, room_id, event_id, content): async def send_invite_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/invite/%s/%s", room_id, event_id) path = _create_v1_path("/invite/%s/%s", room_id, event_id)

View file

@ -1,6 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright 2018 New Vector Ltd # Copyright 2020 Sorunome
# Copyright 2019 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import functools import functools
import logging import logging
import re import re
@ -35,6 +33,7 @@ from synapse.http.servlet import (
parse_integer_from_args, parse_integer_from_args,
parse_json_object_from_request, parse_json_object_from_request,
parse_string_from_args, parse_string_from_args,
parse_strings_from_args,
) )
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.logging.opentracing import ( from synapse.logging.opentracing import (
@ -565,6 +564,34 @@ class FederationV2SendLeaveServlet(BaseFederationServerServlet):
return 200, content return 200, content
class FederationMakeKnockServlet(BaseFederationServerServlet):
PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
PREFIX = FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock"
async def on_GET(self, origin, content, query, room_id, user_id):
try:
# Retrieve the room versions the remote homeserver claims to support
supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8")
except KeyError:
raise SynapseError(400, "Missing required query parameter 'ver'")
content = await self.handler.on_make_knock_request(
origin, room_id, user_id, supported_versions=supported_versions
)
return 200, content
class FederationV1SendKnockServlet(BaseFederationServerServlet):
PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock"
async def on_PUT(self, origin, content, query, room_id, event_id):
content = await self.handler.on_send_knock_request(origin, content, room_id)
return 200, content
class FederationEventAuthServlet(BaseFederationServerServlet): class FederationEventAuthServlet(BaseFederationServerServlet):
PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
@ -1624,6 +1651,13 @@ GROUP_ATTESTATION_SERVLET_CLASSES = (
FederationGroupsRenewAttestaionServlet, FederationGroupsRenewAttestaionServlet,
) # type: Tuple[Type[BaseFederationServlet], ...] ) # type: Tuple[Type[BaseFederationServlet], ...]
MSC2403_SERVLET_CLASSES = (
FederationV1SendKnockServlet,
FederationMakeKnockServlet,
)
DEFAULT_SERVLET_GROUPS = ( DEFAULT_SERVLET_GROUPS = (
"federation", "federation",
"room_list", "room_list",
@ -1666,6 +1700,16 @@ def register_servlets(
server_name=hs.hostname, server_name=hs.hostname,
).register(resource) ).register(resource)
# Register msc2403 (knocking) servlets if the feature is enabled
if hs.config.experimental.msc2403_enabled:
for servletclass in MSC2403_SERVLET_CLASSES:
servletclass(
hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
if "openid" in servlet_groups: if "openid" in servlet_groups:
for servletclass in OPENID_SERVLET_CLASSES: for servletclass in OPENID_SERVLET_CLASSES:
servletclass( servletclass(

View file

@ -1,6 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright 2017-2018 New Vector Ltd # Copyright 2020 Sorunome
# Copyright 2019 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -1550,6 +1549,77 @@ class FederationHandler(BaseHandler):
run_in_background(self._handle_queued_pdus, room_queue) run_in_background(self._handle_queued_pdus, room_queue)
@log_function
async def do_knock(
self,
target_hosts: List[str],
room_id: str,
knockee: str,
content: JsonDict,
) -> Tuple[str, int]:
"""Sends the knock to the remote server.
This first triggers a make_knock request that returns a partial
event that we can fill out and sign. This is then sent to the
remote server via send_knock.
Knock events must be signed by the knockee's server before distributing.
Args:
target_hosts: A list of hosts that we want to try knocking through.
room_id: The ID of the room to knock on.
knockee: The ID of the user who is knocking.
content: The content of the knock event.
Returns:
A tuple of (event ID, stream ID).
Raises:
SynapseError: If the chosen remote server returns a 3xx/4xx code.
RuntimeError: If no servers were reachable.
"""
logger.debug("Knocking on room %s on behalf of user %s", room_id, knockee)
# Inform the remote server of the room versions we support
supported_room_versions = list(KNOWN_ROOM_VERSIONS.keys())
# Ask the remote server to create a valid knock event for us. Once received,
# we sign the event
params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]]
origin, event, event_format_version = await self._make_and_verify_event(
target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
)
# Record the room ID and its version so that we have a record of the room
await self._maybe_store_room_on_outlier_membership(
room_id=event.room_id, room_version=event_format_version
)
# Initially try the host that we successfully called /make_knock on
try:
target_hosts.remove(origin)
target_hosts.insert(0, origin)
except ValueError:
pass
# Send the signed event back to the room, and potentially receive some
# further information about the room in the form of partial state events
stripped_room_state = await self.federation_client.send_knock(
target_hosts, event
)
# Store any stripped room state events in the "unsigned" key of the event.
# This is a bit of a hack and is cribbing off of invites. Basically we
# store the room state here and retrieve it again when this event appears
# in the invitee's sync stream. It is stripped out for all other local users.
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
context = await self.state_handler.compute_event_context(event)
stream_id = await self.persist_events_and_notify(
event.room_id, [(event, context)]
)
return event.event_id, stream_id
async def _handle_queued_pdus( async def _handle_queued_pdus(
self, room_queue: List[Tuple[EventBase, str]] self, room_queue: List[Tuple[EventBase, str]]
) -> None: ) -> None:
@ -1915,6 +1985,116 @@ class FederationHandler(BaseHandler):
return None return None
@log_function
async def on_make_knock_request(
self, origin: str, room_id: str, user_id: str
) -> EventBase:
"""We've received a make_knock request, so we create a partial
knock event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
Args:
origin: The (verified) server name of the requesting server.
room_id: The room to create the knock event in.
user_id: The user to create the knock for.
Returns:
The partial knock event.
"""
if get_domain_from_id(user_id) != origin:
logger.info(
"Get /xyz.amorgan.knock/make_knock request for user %r"
"from different origin %s, ignoring",
user_id,
origin,
)
raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
room_version = await self.store.get_room_version_id(room_id)
builder = self.event_builder_factory.new(
room_version,
{
"type": EventTypes.Member,
"content": {"membership": Membership.KNOCK},
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
},
)
event, context = await self.event_creation_handler.create_new_client_event(
builder=builder
)
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
logger.warning("Creation of knock %s forbidden by third-party rules", event)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_knock_request`
await self.auth.check_from_context(
room_version, event, context, do_sig_check=False
)
except AuthError as e:
logger.warning("Failed to create new knock %r because %s", event, e)
raise e
return event
@log_function
async def on_send_knock_request(
self, origin: str, event: EventBase
) -> EventContext:
"""
We have received a knock event for a room. Verify that event and send it into the room
on the knocking homeserver's behalf.
Args:
origin: The remote homeserver of the knocking user.
event: The knocking member event that has been signed by the remote homeserver.
Returns:
The context of the event after inserting it into the room graph.
"""
logger.debug(
"on_send_knock_request: Got event: %s, signatures: %s",
event.event_id,
event.signatures,
)
if get_domain_from_id(event.sender) != origin:
logger.info(
"Got /xyz.amorgan.knock/send_knock request for user %r "
"from different origin %s",
event.sender,
origin,
)
raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
event.internal_metadata.outlier = False
context = await self.state_handler.compute_event_context(event)
await self._auth_and_persist_event(origin, event, context)
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
logger.info("Sending of knock %s forbidden by third-party rules", event)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
return context
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event.""" """Returns the state at the event. i.e. not including said event."""

View file

@ -1,6 +1,7 @@
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd # Copyright 2017-2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright 2019-2020 The Matrix.org Foundation C.I.C.
# Copyrignt 2020 Sorunome
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -398,13 +399,14 @@ class EventCreationHandler:
self._events_shard_config = self.config.worker.events_shard_config self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
self.room_invite_state_types = self.hs.config.api.room_prejoin_state self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
self.membership_types_to_include_profile_data_in = ( self.membership_types_to_include_profile_data_in = {
{Membership.JOIN, Membership.INVITE} Membership.JOIN,
if self.hs.config.include_profile_data_on_invite Membership.KNOCK,
else {Membership.JOIN} }
) if self.hs.config.include_profile_data_on_invite:
self.membership_types_to_include_profile_data_in.add(Membership.INVITE)
self.send_event = ReplicationSendEventRestServlet.make_client(hs) self.send_event = ReplicationSendEventRestServlet.make_client(hs)
@ -961,8 +963,8 @@ class EventCreationHandler:
room_version = await self.store.get_room_version_id(event.room_id) room_version = await self.store.get_room_version_id(event.room_id)
if event.internal_metadata.is_out_of_band_membership(): if event.internal_metadata.is_out_of_band_membership():
# the only sort of out-of-band-membership events we expect to see here # the only sort of out-of-band-membership events we expect to see here are
# are invite rejections we have generated ourselves. # invite rejections and rescinded knocks that we have generated ourselves.
assert event.type == EventTypes.Member assert event.type == EventTypes.Member
assert event.content["membership"] == Membership.LEAVE assert event.content["membership"] == Membership.LEAVE
else: else:
@ -1239,7 +1241,7 @@ class EventCreationHandler:
"invite_room_state" "invite_room_state"
] = await self.store.get_stripped_room_state_from_event_context( ] = await self.store.get_stripped_room_state_from_event_context(
context, context,
self.room_invite_state_types, self.room_prejoin_state_types,
membership_user_id=event.sender, membership_user_id=event.sender,
) )
@ -1257,6 +1259,14 @@ class EventCreationHandler:
# TODO: Make sure the signatures actually are correct. # TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures) event.signatures.update(returned_invite.signatures)
if event.content["membership"] == Membership.KNOCK:
event.unsigned[
"knock_room_state"
] = await self.store.get_stripped_room_state_from_event_context(
context,
self.room_prejoin_state_types,
)
if event.type == EventTypes.Redaction: if event.type == EventTypes.Redaction:
original_event = await self.store.get_event( original_event = await self.store.get_event(
event.redacts, event.redacts,

View file

@ -1,4 +1,5 @@
# Copyright 2016-2020 The Matrix.org Foundation C.I.C. # Copyright 2016-2020 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -11,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc import abc
import logging import logging
import random import random
@ -30,7 +30,15 @@ from synapse.api.errors import (
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.types import (
JsonDict,
Requester,
RoomAlias,
RoomID,
StateMap,
UserID,
get_domain_from_id,
)
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_left_room from synapse.util.distributor import user_left_room
@ -125,6 +133,24 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
""" """
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod
async def remote_knock(
self,
remote_room_hosts: List[str],
room_id: str,
user: UserID,
content: dict,
) -> Tuple[str, int]:
"""Try and knock on a room that this server is not in
Args:
remote_room_hosts: List of servers that can be used to knock via.
room_id: Room that we are trying to knock on.
user: User who is trying to knock.
content: A dict that should be used as the content of the knock event.
"""
raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod
async def remote_reject_invite( async def remote_reject_invite(
self, self,
@ -148,6 +174,27 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
""" """
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod
async def remote_rescind_knock(
self,
knock_event_id: str,
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
) -> Tuple[str, int]:
"""Rescind a local knock made on a remote room.
Args:
knock_event_id: The ID of the knock event to rescind.
txn_id: An optional transaction ID supplied by the client.
requester: The user making the request, according to the access token.
content: The content of the generated leave event.
Returns:
A tuple containing (event_id, stream_id of the leave event).
"""
raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod
async def _user_left_room(self, target: UserID, room_id: str) -> None: async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has left the """Notifies distributor on master process that the user has left the
@ -603,29 +650,31 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
elif effective_membership_state == Membership.LEAVE: elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room: if not is_host_in_room:
# perhaps we've been invited # Figure out the user's current membership state for the room
( (
current_membership_type, current_membership_type,
current_membership_event_id, current_membership_event_id,
) = await self.store.get_local_current_membership_for_user_in_room( ) = await self.store.get_local_current_membership_for_user_in_room(
target.to_string(), room_id target.to_string(), room_id
) )
if ( if not current_membership_type or not current_membership_event_id:
current_membership_type != Membership.INVITE
or not current_membership_event_id
):
logger.info( logger.info(
"%s sent a leave request to %s, but that is not an active room " "%s sent a leave request to %s, but that is not an active room "
"on this server, and there is no pending invite", "on this server, or there is no pending invite or knock",
target, target,
room_id, room_id,
) )
raise SynapseError(404, "Not a known room") raise SynapseError(404, "Not a known room")
# perhaps we've been invited
if current_membership_type == Membership.INVITE:
invite = await self.store.get_event(current_membership_event_id) invite = await self.store.get_event(current_membership_event_id)
logger.info( logger.info(
"%s rejects invite to %s from %s", target, room_id, invite.sender "%s rejects invite to %s from %s",
target,
room_id,
invite.sender,
) )
if not self.hs.is_mine_id(invite.sender): if not self.hs.is_mine_id(invite.sender):
@ -651,6 +700,33 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if len(latest_event_ids) == 0: if len(latest_event_ids) == 0:
latest_event_ids = [invite.event_id] latest_event_ids = [invite.event_id]
# or perhaps this is a remote room that a local user has knocked on
elif current_membership_type == Membership.KNOCK:
knock = await self.store.get_event(current_membership_event_id)
return await self.remote_rescind_knock(
knock.event_id, txn_id, requester, content
)
elif (
self.config.experimental.msc2403_enabled
and effective_membership_state == Membership.KNOCK
):
if not is_host_in_room:
# The knock needs to be sent over federation instead
remote_room_hosts.append(get_domain_from_id(room_id))
content["membership"] = Membership.KNOCK
profile = self.profile_handler
if "displayname" not in content:
content["displayname"] = await profile.get_displayname(target)
if "avatar_url" not in content:
content["avatar_url"] = await profile.get_avatar_url(target)
return await self.remote_knock(
remote_room_hosts, room_id, target, content
)
return await self._local_membership_update( return await self._local_membership_update(
requester=requester, requester=requester,
target=target, target=target,
@ -1209,6 +1285,35 @@ class RoomMemberMasterHandler(RoomMemberHandler):
invite_event, txn_id, requester, content invite_event, txn_id, requester, content
) )
async def remote_rescind_knock(
self,
knock_event_id: str,
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
) -> Tuple[str, int]:
"""
Rescinds a local knock made on a remote room
Args:
knock_event_id: The ID of the knock event to rescind.
txn_id: The transaction ID to use.
requester: The originator of the request.
content: The content of the leave event.
Implements RoomMemberHandler.remote_rescind_knock
"""
# TODO: We don't yet support rescinding knocks over federation
# as we don't know which homeserver to send it to. An obvious
# candidate is the remote homeserver we originally knocked through,
# however we don't currently store that information.
# Just rescind the knock locally
knock_event = await self.store.get_event(knock_event_id)
return await self._generate_local_out_of_band_leave(
knock_event, txn_id, requester, content
)
async def _generate_local_out_of_band_leave( async def _generate_local_out_of_band_leave(
self, self,
previous_membership_event: EventBase, previous_membership_event: EventBase,
@ -1272,6 +1377,36 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return result_event.event_id, result_event.internal_metadata.stream_ordering return result_event.event_id, result_event.internal_metadata.stream_ordering
async def remote_knock(
self,
remote_room_hosts: List[str],
room_id: str,
user: UserID,
content: dict,
) -> Tuple[str, int]:
"""Sends a knock to a room. Attempts to do so via one remote out of a given list.
Args:
remote_room_hosts: A list of homeservers to try knocking through.
room_id: The ID of the room to knock on.
user: The user to knock on behalf of.
content: The content of the knock event.
Returns:
A tuple of (event ID, stream ID).
"""
# filter ourselves out of remote_room_hosts
remote_room_hosts = [
host for host in remote_room_hosts if host != self.hs.hostname
]
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
return await self.federation_handler.do_knock(
remote_room_hosts, room_id, user.to_string(), content=content
)
async def _user_left_room(self, target: UserID, room_id: str) -> None: async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room""" """Implements RoomMemberHandler._user_left_room"""
user_left_room(self.distributor, target, room_id) user_left_room(self.distributor, target, room_id)

View file

@ -1,4 +1,4 @@
# Copyright 2018 New Vector Ltd # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -19,10 +19,12 @@ from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler from synapse.handlers.room_member import RoomMemberHandler
from synapse.replication.http.membership import ( from synapse.replication.http.membership import (
ReplicationRemoteJoinRestServlet as ReplRemoteJoin, ReplicationRemoteJoinRestServlet as ReplRemoteJoin,
ReplicationRemoteKnockRestServlet as ReplRemoteKnock,
ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite, ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite,
ReplicationRemoteRescindKnockRestServlet as ReplRescindKnock,
ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft, ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft,
) )
from synapse.types import Requester, UserID from synapse.types import JsonDict, Requester, UserID
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -35,7 +37,9 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
super().__init__(hs) super().__init__(hs)
self._remote_join_client = ReplRemoteJoin.make_client(hs) self._remote_join_client = ReplRemoteJoin.make_client(hs)
self._remote_knock_client = ReplRemoteKnock.make_client(hs)
self._remote_reject_client = ReplRejectInvite.make_client(hs) self._remote_reject_client = ReplRejectInvite.make_client(hs)
self._remote_rescind_client = ReplRescindKnock.make_client(hs)
self._notify_change_client = ReplJoinedLeft.make_client(hs) self._notify_change_client = ReplJoinedLeft.make_client(hs)
async def _remote_join( async def _remote_join(
@ -80,6 +84,53 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
) )
return ret["event_id"], ret["stream_id"] return ret["event_id"], ret["stream_id"]
async def remote_rescind_knock(
self,
knock_event_id: str,
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
) -> Tuple[str, int]:
"""
Rescinds a local knock made on a remote room
Args:
knock_event_id: the knock event
txn_id: optional transaction ID supplied by the client
requester: user making the request, according to the access token
content: additional content to include in the leave event.
Normally an empty dict.
Returns:
A tuple containing (event_id, stream_id of the leave event)
"""
ret = await self._remote_rescind_client(
knock_event_id=knock_event_id,
txn_id=txn_id,
requester=requester,
content=content,
)
return ret["event_id"], ret["stream_id"]
async def remote_knock(
self,
remote_room_hosts: List[str],
room_id: str,
user: UserID,
content: dict,
) -> Tuple[str, int]:
"""Sends a knock to a room.
Implements RoomMemberHandler.remote_knock
"""
ret = await self._remote_knock_client(
remote_room_hosts=remote_room_hosts,
room_id=room_id,
user=user,
content=content,
)
return ret["event_id"], ret["stream_id"]
async def _user_left_room(self, target: UserID, room_id: str) -> None: async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room""" """Implements RoomMemberHandler._user_left_room"""
await self._notify_change_client( await self._notify_change_client(

View file

@ -1,4 +1,5 @@
# Copyright 2018 New Vector Ltd # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# Copyright 2020 Sorunome
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -230,6 +231,8 @@ class StatsHandler:
room_stats_delta["left_members"] -= 1 room_stats_delta["left_members"] -= 1
elif prev_membership == Membership.BAN: elif prev_membership == Membership.BAN:
room_stats_delta["banned_members"] -= 1 room_stats_delta["banned_members"] -= 1
elif prev_membership == Membership.KNOCK:
room_stats_delta["knocked_members"] -= 1
else: else:
raise ValueError( raise ValueError(
"%r is not a valid prev_membership" % (prev_membership,) "%r is not a valid prev_membership" % (prev_membership,)
@ -251,6 +254,8 @@ class StatsHandler:
room_stats_delta["left_members"] += 1 room_stats_delta["left_members"] += 1
elif membership == Membership.BAN: elif membership == Membership.BAN:
room_stats_delta["banned_members"] += 1 room_stats_delta["banned_members"] += 1
elif membership == Membership.KNOCK:
room_stats_delta["knocked_members"] += 1
else: else:
raise ValueError("%r is not a valid membership" % (membership,)) raise ValueError("%r is not a valid membership" % (membership,))

View file

@ -159,6 +159,16 @@ class InvitedSyncResult:
return True return True
@attr.s(slots=True, frozen=True)
class KnockedSyncResult:
room_id = attr.ib(type=str)
knock = attr.ib(type=EventBase)
def __bool__(self) -> bool:
"""Knocked rooms should always be reported to the client"""
return True
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
class GroupsSyncResult: class GroupsSyncResult:
join = attr.ib(type=JsonDict) join = attr.ib(type=JsonDict)
@ -192,6 +202,7 @@ class _RoomChanges:
room_entries = attr.ib(type=List["RoomSyncResultBuilder"]) room_entries = attr.ib(type=List["RoomSyncResultBuilder"])
invited = attr.ib(type=List[InvitedSyncResult]) invited = attr.ib(type=List[InvitedSyncResult])
knocked = attr.ib(type=List[KnockedSyncResult])
newly_joined_rooms = attr.ib(type=List[str]) newly_joined_rooms = attr.ib(type=List[str])
newly_left_rooms = attr.ib(type=List[str]) newly_left_rooms = attr.ib(type=List[str])
@ -205,6 +216,7 @@ class SyncResult:
account_data: List of account_data events for the user. account_data: List of account_data events for the user.
joined: JoinedSyncResult for each joined room. joined: JoinedSyncResult for each joined room.
invited: InvitedSyncResult for each invited room. invited: InvitedSyncResult for each invited room.
knocked: KnockedSyncResult for each knocked on room.
archived: ArchivedSyncResult for each archived room. archived: ArchivedSyncResult for each archived room.
to_device: List of direct messages for the device. to_device: List of direct messages for the device.
device_lists: List of user_ids whose devices have changed device_lists: List of user_ids whose devices have changed
@ -220,6 +232,7 @@ class SyncResult:
account_data = attr.ib(type=List[JsonDict]) account_data = attr.ib(type=List[JsonDict])
joined = attr.ib(type=List[JoinedSyncResult]) joined = attr.ib(type=List[JoinedSyncResult])
invited = attr.ib(type=List[InvitedSyncResult]) invited = attr.ib(type=List[InvitedSyncResult])
knocked = attr.ib(type=List[KnockedSyncResult])
archived = attr.ib(type=List[ArchivedSyncResult]) archived = attr.ib(type=List[ArchivedSyncResult])
to_device = attr.ib(type=List[JsonDict]) to_device = attr.ib(type=List[JsonDict])
device_lists = attr.ib(type=DeviceLists) device_lists = attr.ib(type=DeviceLists)
@ -236,6 +249,7 @@ class SyncResult:
self.presence self.presence
or self.joined or self.joined
or self.invited or self.invited
or self.knocked
or self.archived or self.archived
or self.account_data or self.account_data
or self.to_device or self.to_device
@ -1031,7 +1045,7 @@ class SyncHandler:
res = await self._generate_sync_entry_for_rooms( res = await self._generate_sync_entry_for_rooms(
sync_result_builder, account_data_by_room sync_result_builder, account_data_by_room
) )
newly_joined_rooms, newly_joined_or_invited_users, _, _ = res newly_joined_rooms, newly_joined_or_invited_or_knocked_users, _, _ = res
_, _, newly_left_rooms, newly_left_users = res _, _, newly_left_rooms, newly_left_users = res
block_all_presence_data = ( block_all_presence_data = (
@ -1040,7 +1054,9 @@ class SyncHandler:
if self.hs_config.use_presence and not block_all_presence_data: if self.hs_config.use_presence and not block_all_presence_data:
logger.debug("Fetching presence data") logger.debug("Fetching presence data")
await self._generate_sync_entry_for_presence( await self._generate_sync_entry_for_presence(
sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users sync_result_builder,
newly_joined_rooms,
newly_joined_or_invited_or_knocked_users,
) )
logger.debug("Fetching to-device data") logger.debug("Fetching to-device data")
@ -1049,7 +1065,7 @@ class SyncHandler:
device_lists = await self._generate_sync_entry_for_device_list( device_lists = await self._generate_sync_entry_for_device_list(
sync_result_builder, sync_result_builder,
newly_joined_rooms=newly_joined_rooms, newly_joined_rooms=newly_joined_rooms,
newly_joined_or_invited_users=newly_joined_or_invited_users, newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
newly_left_rooms=newly_left_rooms, newly_left_rooms=newly_left_rooms,
newly_left_users=newly_left_users, newly_left_users=newly_left_users,
) )
@ -1083,6 +1099,7 @@ class SyncHandler:
account_data=sync_result_builder.account_data, account_data=sync_result_builder.account_data,
joined=sync_result_builder.joined, joined=sync_result_builder.joined,
invited=sync_result_builder.invited, invited=sync_result_builder.invited,
knocked=sync_result_builder.knocked,
archived=sync_result_builder.archived, archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device, to_device=sync_result_builder.to_device,
device_lists=device_lists, device_lists=device_lists,
@ -1142,7 +1159,7 @@ class SyncHandler:
self, self,
sync_result_builder: "SyncResultBuilder", sync_result_builder: "SyncResultBuilder",
newly_joined_rooms: Set[str], newly_joined_rooms: Set[str],
newly_joined_or_invited_users: Set[str], newly_joined_or_invited_or_knocked_users: Set[str],
newly_left_rooms: Set[str], newly_left_rooms: Set[str],
newly_left_users: Set[str], newly_left_users: Set[str],
) -> DeviceLists: ) -> DeviceLists:
@ -1151,8 +1168,9 @@ class SyncHandler:
Args: Args:
sync_result_builder sync_result_builder
newly_joined_rooms: Set of rooms user has joined since previous sync newly_joined_rooms: Set of rooms user has joined since previous sync
newly_joined_or_invited_users: Set of users that have joined or newly_joined_or_invited_or_knocked_users: Set of users that have joined,
been invited to a room since previous sync. been invited to a room or are knocking on a room since
previous sync.
newly_left_rooms: Set of rooms user has left since previous sync newly_left_rooms: Set of rooms user has left since previous sync
newly_left_users: Set of users that have left a room we're in since newly_left_users: Set of users that have left a room we're in since
previous sync previous sync
@ -1163,7 +1181,9 @@ class SyncHandler:
# We're going to mutate these fields, so lets copy them rather than # We're going to mutate these fields, so lets copy them rather than
# assume they won't get used later. # assume they won't get used later.
newly_joined_or_invited_users = set(newly_joined_or_invited_users) newly_joined_or_invited_or_knocked_users = set(
newly_joined_or_invited_or_knocked_users
)
newly_left_users = set(newly_left_users) newly_left_users = set(newly_left_users)
if since_token and since_token.device_list_key: if since_token and since_token.device_list_key:
@ -1202,11 +1222,11 @@ class SyncHandler:
# Step 1b, check for newly joined rooms # Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms: for room_id in newly_joined_rooms:
joined_users = await self.store.get_users_in_room(room_id) joined_users = await self.store.get_users_in_room(room_id)
newly_joined_or_invited_users.update(joined_users) newly_joined_or_invited_or_knocked_users.update(joined_users)
# TODO: Check that these users are actually new, i.e. either they # TODO: Check that these users are actually new, i.e. either they
# weren't in the previous sync *or* they left and rejoined. # weren't in the previous sync *or* they left and rejoined.
users_that_have_changed.update(newly_joined_or_invited_users) users_that_have_changed.update(newly_joined_or_invited_or_knocked_users)
user_signatures_changed = ( user_signatures_changed = (
await self.store.get_users_whose_signatures_changed( await self.store.get_users_whose_signatures_changed(
@ -1452,6 +1472,7 @@ class SyncHandler:
room_entries = room_changes.room_entries room_entries = room_changes.room_entries
invited = room_changes.invited invited = room_changes.invited
knocked = room_changes.knocked
newly_joined_rooms = room_changes.newly_joined_rooms newly_joined_rooms = room_changes.newly_joined_rooms
newly_left_rooms = room_changes.newly_left_rooms newly_left_rooms = room_changes.newly_left_rooms
@ -1472,9 +1493,10 @@ class SyncHandler:
await concurrently_execute(handle_room_entries, room_entries, 10) await concurrently_execute(handle_room_entries, room_entries, 10)
sync_result_builder.invited.extend(invited) sync_result_builder.invited.extend(invited)
sync_result_builder.knocked.extend(knocked)
# Now we want to get any newly joined or invited users # Now we want to get any newly joined, invited or knocking users
newly_joined_or_invited_users = set() newly_joined_or_invited_or_knocked_users = set()
newly_left_users = set() newly_left_users = set()
if since_token: if since_token:
for joined_sync in sync_result_builder.joined: for joined_sync in sync_result_builder.joined:
@ -1486,19 +1508,22 @@ class SyncHandler:
if ( if (
event.membership == Membership.JOIN event.membership == Membership.JOIN
or event.membership == Membership.INVITE or event.membership == Membership.INVITE
or event.membership == Membership.KNOCK
): ):
newly_joined_or_invited_users.add(event.state_key) newly_joined_or_invited_or_knocked_users.add(
event.state_key
)
else: else:
prev_content = event.unsigned.get("prev_content", {}) prev_content = event.unsigned.get("prev_content", {})
prev_membership = prev_content.get("membership", None) prev_membership = prev_content.get("membership", None)
if prev_membership == Membership.JOIN: if prev_membership == Membership.JOIN:
newly_left_users.add(event.state_key) newly_left_users.add(event.state_key)
newly_left_users -= newly_joined_or_invited_users newly_left_users -= newly_joined_or_invited_or_knocked_users
return ( return (
set(newly_joined_rooms), set(newly_joined_rooms),
newly_joined_or_invited_users, newly_joined_or_invited_or_knocked_users,
set(newly_left_rooms), set(newly_left_rooms),
newly_left_users, newly_left_users,
) )
@ -1553,6 +1578,7 @@ class SyncHandler:
newly_left_rooms = [] newly_left_rooms = []
room_entries = [] room_entries = []
invited = [] invited = []
knocked = []
for room_id, events in mem_change_events_by_room_id.items(): for room_id, events in mem_change_events_by_room_id.items():
logger.debug( logger.debug(
"Membership changes in %s: [%s]", "Membership changes in %s: [%s]",
@ -1632,9 +1658,17 @@ class SyncHandler:
should_invite = non_joins[-1].membership == Membership.INVITE should_invite = non_joins[-1].membership == Membership.INVITE
if should_invite: if should_invite:
if event.sender not in ignored_users: if event.sender not in ignored_users:
room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
if room_sync: if invite_room_sync:
invited.append(room_sync) invited.append(invite_room_sync)
# Only bother if our latest membership in the room is knock (and we haven't
# been accepted/rejected in the meantime).
should_knock = non_joins[-1].membership == Membership.KNOCK
if should_knock:
knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1])
if knock_room_sync:
knocked.append(knock_room_sync)
# Always include leave/ban events. Just take the last one. # Always include leave/ban events. Just take the last one.
# TODO: How do we handle ban -> leave in same batch? # TODO: How do we handle ban -> leave in same batch?
@ -1738,7 +1772,13 @@ class SyncHandler:
) )
room_entries.append(entry) room_entries.append(entry)
return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms) return _RoomChanges(
room_entries,
invited,
knocked,
newly_joined_rooms,
newly_left_rooms,
)
async def _get_all_rooms( async def _get_all_rooms(
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
@ -1758,6 +1798,7 @@ class SyncHandler:
membership_list = ( membership_list = (
Membership.INVITE, Membership.INVITE,
Membership.KNOCK,
Membership.JOIN, Membership.JOIN,
Membership.LEAVE, Membership.LEAVE,
Membership.BAN, Membership.BAN,
@ -1769,6 +1810,7 @@ class SyncHandler:
room_entries = [] room_entries = []
invited = [] invited = []
knocked = []
for event in room_list: for event in room_list:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
@ -1788,8 +1830,11 @@ class SyncHandler:
continue continue
invite = await self.store.get_event(event.event_id) invite = await self.store.get_event(event.event_id)
invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite)) invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite))
elif event.membership == Membership.KNOCK:
knock = await self.store.get_event(event.event_id)
knocked.append(KnockedSyncResult(room_id=event.room_id, knock=knock))
elif event.membership in (Membership.LEAVE, Membership.BAN): elif event.membership in (Membership.LEAVE, Membership.BAN):
# Always send down rooms we were banned or kicked from. # Always send down rooms we were banned from or kicked from.
if not sync_config.filter_collection.include_leave: if not sync_config.filter_collection.include_leave:
if event.membership == Membership.LEAVE: if event.membership == Membership.LEAVE:
if user_id == event.sender: if user_id == event.sender:
@ -1810,7 +1855,7 @@ class SyncHandler:
) )
) )
return _RoomChanges(room_entries, invited, [], []) return _RoomChanges(room_entries, invited, knocked, [], [])
async def _generate_room_entry( async def _generate_room_entry(
self, self,
@ -2101,6 +2146,7 @@ class SyncResultBuilder:
account_data (list) account_data (list)
joined (list[JoinedSyncResult]) joined (list[JoinedSyncResult])
invited (list[InvitedSyncResult]) invited (list[InvitedSyncResult])
knocked (list[KnockedSyncResult])
archived (list[ArchivedSyncResult]) archived (list[ArchivedSyncResult])
groups (GroupsSyncResult|None) groups (GroupsSyncResult|None)
to_device (list) to_device (list)
@ -2116,6 +2162,7 @@ class SyncResultBuilder:
account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list)) account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list))
joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list)) joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list))
invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list)) invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list))
knocked = attr.ib(type=List[KnockedSyncResult], default=attr.Factory(list))
archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list)) archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list))
groups = attr.ib(type=Optional[GroupsSyncResult], default=None) groups = attr.ib(type=Optional[GroupsSyncResult], default=None)
to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list)) to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))

View file

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """ """ This module contains base REST classes for constructing REST servlets. """
import logging import logging
from typing import Dict, Iterable, List, Optional, overload from typing import Dict, Iterable, List, Optional, overload

View file

@ -97,6 +97,76 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
return 200, {"event_id": event_id, "stream_id": stream_id} return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
"""Perform a remote knock for the given user on the given room
Request format:
POST /_synapse/replication/remote_knock/:room_id/:user_id
{
"requester": ...,
"remote_room_hosts": [...],
"content": { ... }
}
"""
NAME = "remote_knock"
PATH_ARGS = ("room_id", "user_id")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.federation_handler = hs.get_federation_handler()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload( # type: ignore
requester: Requester,
room_id: str,
user_id: str,
remote_room_hosts: List[str],
content: JsonDict,
):
"""
Args:
requester: The user making the request, according to the access token.
room_id: The ID of the room to knock on.
user_id: The ID of the knocking user.
remote_room_hosts: Servers to try and send the knock via.
content: The event content to use for the knock event.
"""
return {
"requester": requester.serialize(),
"remote_room_hosts": remote_room_hosts,
"content": content,
}
async def _handle_request( # type: ignore
self,
request: SynapseRequest,
room_id: str,
user_id: str,
):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
request.requester = requester
logger.debug("remote_knock: %s on room: %s", user_id, room_id)
event_id, stream_id = await self.federation_handler.do_knock(
remote_room_hosts, room_id, user_id, event_content
)
return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"""Rejects an out-of-band invite we have received from a remote server """Rejects an out-of-band invite we have received from a remote server
@ -167,6 +237,75 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
return 200, {"event_id": event_id, "stream_id": stream_id} return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
"""Rescinds a local knock made on a remote room
Request format:
POST /_synapse/replication/remote_rescind_knock/:event_id
{
"txn_id": ...,
"requester": ...,
"content": { ... }
}
"""
NAME = "remote_rescind_knock"
PATH_ARGS = ("knock_event_id",)
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.member_handler = hs.get_room_member_handler()
@staticmethod
async def _serialize_payload( # type: ignore
knock_event_id: str,
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
):
"""
Args:
knock_event_id: The ID of the knock to be rescinded.
txn_id: An optional transaction ID supplied by the client.
requester: The user making the rescind request, according to the access token.
content: The content to include in the rescind event.
"""
return {
"txn_id": txn_id,
"requester": requester.serialize(),
"content": content,
}
async def _handle_request( # type: ignore
self,
request: SynapseRequest,
knock_event_id: str,
):
content = parse_json_object_from_request(request)
txn_id = content["txn_id"]
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
request.requester = requester
# hopefully we're now on the master, so this won't recurse!
event_id, stream_id = await self.member_handler.remote_rescind_knock(
knock_event_id,
txn_id,
requester,
event_content,
)
return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
"""Notifies that a user has joined or left the room """Notifies that a user has joined or left the room

View file

@ -38,6 +38,7 @@ from synapse.rest.client.v2_alpha import (
filter, filter,
groups, groups,
keys, keys,
knock,
notifications, notifications,
openid, openid,
password_policy, password_policy,
@ -121,6 +122,10 @@ class ClientRestResource(JsonResource):
relations.register_servlets(hs, client_resource) relations.register_servlets(hs, client_resource)
password_policy.register_servlets(hs, client_resource) password_policy.register_servlets(hs, client_resource)
# Register msc2403 (knocking) servlets if the feature is enabled
if hs.config.experimental.msc2403_enabled:
knock.register_servlets(hs, client_resource)
# moving to /_synapse/admin # moving to /_synapse/admin
admin.register_servlets_for_client_rest_resource(hs, client_resource) admin.register_servlets_for_client_rest_resource(hs, client_resource)

View file

@ -14,10 +14,9 @@
# limitations under the License. # limitations under the License.
""" This module contains REST servlets to do with rooms: /rooms/<paths> """ """ This module contains REST servlets to do with rooms: /rooms/<paths> """
import logging import logging
import re import re
from typing import TYPE_CHECKING, List, Optional, Tuple from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from urllib import parse as urlparse from urllib import parse as urlparse
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
@ -38,6 +37,7 @@ from synapse.http.servlet import (
parse_integer, parse_integer,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
parse_strings_from_args,
) )
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag from synapse.logging.opentracing import set_tag
@ -278,7 +278,12 @@ class JoinRoomAliasServlet(TransactionRestServlet):
PATTERNS = "/join/(?P<room_identifier>[^/]*)" PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
async def on_POST(self, request, room_identifier, txn_id=None): async def on_POST(
self,
request: SynapseRequest,
room_identifier: str,
txn_id: Optional[str] = None,
):
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
try: try:
@ -290,17 +295,18 @@ class JoinRoomAliasServlet(TransactionRestServlet):
if RoomID.is_valid(room_identifier): if RoomID.is_valid(room_identifier):
room_id = room_identifier room_id = room_identifier
try:
remote_room_hosts = [ # twisted.web.server.Request.args is incorrectly defined as Optional[Any]
x.decode("ascii") for x in request.args[b"server_name"] args: Dict[bytes, List[bytes]] = request.args # type: ignore
] # type: Optional[List[str]]
except Exception: remote_room_hosts = parse_strings_from_args(
remote_room_hosts = None args, "server_name", required=False
)
elif RoomAlias.is_valid(room_identifier): elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier) room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias) room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias)
room_id = room_id.to_string() room_id = room_id_obj.to_string()
else: else:
raise SynapseError( raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,) 400, "%s was not legal room ID or room alias" % (room_identifier,)

View file

@ -0,0 +1,109 @@
# Copyright 2020 Sorunome
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from twisted.web.server import Request
from synapse.api.constants import Membership
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
parse_strings_from_args,
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict, RoomAlias, RoomID
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from ._base import client_patterns
logger = logging.getLogger(__name__)
class KnockRoomAliasServlet(RestServlet):
"""
POST /xyz.amorgan.knock/{roomIdOrAlias}
"""
PATTERNS = client_patterns(
"/xyz.amorgan.knock/(?P<room_identifier>[^/]*)", releases=()
)
def __init__(self, hs: "HomeServer"):
super().__init__()
self.txns = HttpTransactionCache(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
async def on_POST(
self,
request: SynapseRequest,
room_identifier: str,
txn_id: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
event_content = None
if "reason" in content:
event_content = {"reason": content["reason"]}
if RoomID.is_valid(room_identifier):
room_id = room_identifier
# twisted.web.server.Request.args is incorrectly defined as Optional[Any]
args: Dict[bytes, List[bytes]] = request.args # type: ignore
remote_room_hosts = parse_strings_from_args(
args, "server_name", required=False
)
elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias)
room_id = room_id_obj.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
await self.room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
action=Membership.KNOCK,
txn_id=txn_id,
third_party_signed=None,
remote_room_hosts=remote_room_hosts,
content=event_content,
)
return 200, {"room_id": room_id}
def on_PUT(self, request: Request, room_identifier: str, txn_id: str):
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_identifier, txn_id
)
def register_servlets(hs, http_server):
KnockRoomAliasServlet(hs).register(http_server)

View file

@ -11,12 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import itertools import itertools
import logging import logging
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
from synapse.api.constants import PresenceState from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import ( from synapse.events.utils import (
@ -24,7 +23,7 @@ from synapse.events.utils import (
format_event_raw, format_event_raw,
) )
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import SyncConfig from synapse.handlers.sync import KnockedSyncResult, SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, StreamToken from synapse.types import JsonDict, StreamToken
@ -220,6 +219,10 @@ class SyncRestServlet(RestServlet):
sync_result.invited, time_now, access_token_id, event_formatter sync_result.invited, time_now, access_token_id, event_formatter
) )
knocked = await self.encode_knocked(
sync_result.knocked, time_now, access_token_id, event_formatter
)
archived = await self.encode_archived( archived = await self.encode_archived(
sync_result.archived, sync_result.archived,
time_now, time_now,
@ -237,11 +240,16 @@ class SyncRestServlet(RestServlet):
"left": list(sync_result.device_lists.left), "left": list(sync_result.device_lists.left),
}, },
"presence": SyncRestServlet.encode_presence(sync_result.presence, time_now), "presence": SyncRestServlet.encode_presence(sync_result.presence, time_now),
"rooms": {"join": joined, "invite": invited, "leave": archived}, "rooms": {
Membership.JOIN: joined,
Membership.INVITE: invited,
Membership.KNOCK: knocked,
Membership.LEAVE: archived,
},
"groups": { "groups": {
"join": sync_result.groups.join, Membership.JOIN: sync_result.groups.join,
"invite": sync_result.groups.invite, Membership.INVITE: sync_result.groups.invite,
"leave": sync_result.groups.leave, Membership.LEAVE: sync_result.groups.leave,
}, },
"device_one_time_keys_count": sync_result.device_one_time_keys_count, "device_one_time_keys_count": sync_result.device_one_time_keys_count,
"org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types, "org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
@ -303,7 +311,7 @@ class SyncRestServlet(RestServlet):
Args: Args:
rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
sync results for rooms this user is joined to sync results for rooms this user is invited to
time_now(int): current time - used as a baseline for age time_now(int): current time - used as a baseline for age
calculations calculations
token_id(int): ID of the user's auth token - used for namespacing token_id(int): ID of the user's auth token - used for namespacing
@ -322,7 +330,7 @@ class SyncRestServlet(RestServlet):
time_now, time_now,
token_id=token_id, token_id=token_id,
event_format=event_formatter, event_format=event_formatter,
is_invite=True, include_stripped_room_state=True,
) )
unsigned = dict(invite.get("unsigned", {})) unsigned = dict(invite.get("unsigned", {}))
invite["unsigned"] = unsigned invite["unsigned"] = unsigned
@ -332,6 +340,60 @@ class SyncRestServlet(RestServlet):
return invited return invited
async def encode_knocked(
self,
rooms: List[KnockedSyncResult],
time_now: int,
token_id: int,
event_formatter: Callable[[Dict], Dict],
) -> Dict[str, Dict[str, Any]]:
"""
Encode the rooms we've knocked on in a sync result.
Args:
rooms: list of sync results for rooms this user is knocking on
time_now: current time - used as a baseline for age calculations
token_id: ID of the user's auth token - used for namespacing of transaction IDs
event_formatter: function to convert from federation format to client format
Returns:
The list of rooms the user has knocked on, in our response format.
"""
knocked = {}
for room in rooms:
knock = await self._event_serializer.serialize_event(
room.knock,
time_now,
token_id=token_id,
event_format=event_formatter,
include_stripped_room_state=True,
)
# Extract the `unsigned` key from the knock event.
# This is where we (cheekily) store the knock state events
unsigned = knock.setdefault("unsigned", {})
# Duplicate the dictionary in order to avoid modifying the original
unsigned = dict(unsigned)
# Extract the stripped room state from the unsigned dict
# This is for clients to get a little bit of information about
# the room they've knocked on, without revealing any sensitive information
knocked_state = list(unsigned.pop("knock_room_state", []))
# Append the actual knock membership event itself as well. This provides
# the client with:
#
# * A knock state event that they can use for easier internal tracking
# * The rough timestamp of when the knock occurred contained within the event
knocked_state.append(knock)
# Build the `knock_state` dictionary, which will contain the state of the
# room that the client has knocked on
knocked[room.room_id] = {"knock_state": {"events": knocked_state}}
return knocked
async def encode_archived( async def encode_archived(
self, rooms, time_now, token_id, event_fields, event_formatter self, rooms, time_now, token_id, event_fields, event_formatter
): ):

View file

@ -41,6 +41,7 @@ ABSOLUTE_STATS_FIELDS = {
"current_state_events", "current_state_events",
"joined_members", "joined_members",
"invited_members", "invited_members",
"knocked_members",
"left_members", "left_members",
"banned_members", "banned_members",
"local_users_in_room", "local_users_in_room",

View file

@ -0,0 +1,17 @@
/* Copyright 2020 Sorunome
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE room_stats_current ADD COLUMN knocked_members INT NOT NULL DEFAULT '0';
ALTER TABLE room_stats_historical ADD COLUMN knocked_members BIGINT NOT NULL DEFAULT '0';

View file

@ -0,0 +1,302 @@
# Copyright 2020 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import Dict, List
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import builder
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.server import HomeServer
from synapse.types import RoomAlias
from tests.test_utils import event_injection
from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
# An identifier to use while MSC2304 is not in a stable release of the spec
KNOCK_UNSTABLE_IDENTIFIER = "xyz.amorgan.knock"
class KnockingStrippedStateEventHelperMixin(TestCase):
def send_example_state_events_to_room(
self,
hs: "HomeServer",
room_id: str,
sender: str,
) -> OrderedDict:
"""Adds some state to a room. State events are those that should be sent to a knocking
user after they knock on the room, as well as some state that *shouldn't* be sent
to the knocking user.
Args:
hs: The homeserver of the sender.
room_id: The ID of the room to send state into.
sender: The ID of the user to send state as. Must be in the room.
Returns:
The OrderedDict of event types and content that a user is expected to see
after knocking on a room.
"""
# To set a canonical alias, we'll need to point an alias at the room first.
canonical_alias = "#fancy_alias:test"
self.get_success(
self.store.create_room_alias_association(
RoomAlias.from_string(canonical_alias), room_id, ["test"]
)
)
# Send some state that we *don't* expect to be given to knocking users
self.get_success(
event_injection.inject_event(
hs,
room_version=RoomVersions.MSC2403.identifier,
room_id=room_id,
sender=sender,
type="com.example.secret",
state_key="",
content={"secret": "password"},
)
)
# We use an OrderedDict here to ensure that the knock membership appears last.
# Note that order only matters when sending stripped state to clients, not federated
# homeservers.
room_state = OrderedDict(
[
# We need to set the room's join rules to allow knocking
(
EventTypes.JoinRules,
{"content": {"join_rule": JoinRules.KNOCK}, "state_key": ""},
),
# Below are state events that are to be stripped and sent to clients
(
EventTypes.Name,
{"content": {"name": "A cool room"}, "state_key": ""},
),
(
EventTypes.RoomAvatar,
{
"content": {
"info": {
"h": 398,
"mimetype": "image/jpeg",
"size": 31037,
"w": 394,
},
"url": "mxc://example.org/JWEIFJgwEIhweiWJE",
},
"state_key": "",
},
),
(
EventTypes.RoomEncryption,
{"content": {"algorithm": "m.megolm.v1.aes-sha2"}, "state_key": ""},
),
(
EventTypes.CanonicalAlias,
{
"content": {"alias": canonical_alias, "alt_aliases": []},
"state_key": "",
},
),
]
)
for event_type, event_dict in room_state.items():
event_content = event_dict["content"]
state_key = event_dict["state_key"]
self.get_success(
event_injection.inject_event(
hs,
room_version=RoomVersions.MSC2403.identifier,
room_id=room_id,
sender=sender,
type=event_type,
state_key=state_key,
content=event_content,
)
)
# Finally, we expect to see the m.room.create event of the room as part of the
# stripped state. We don't need to inject this event though.
room_state[EventTypes.Create] = {
"content": {
"creator": sender,
"room_version": RoomVersions.MSC2403.identifier,
},
"state_key": "",
}
return room_state
def check_knock_room_state_against_room_state(
self,
knock_room_state: List[Dict],
expected_room_state: Dict,
) -> None:
"""Test a list of stripped room state events received over federation against a
dict of expected state events.
Args:
knock_room_state: The list of room state that was received over federation.
expected_room_state: A dict containing the room state we expect to see in
`knock_room_state`.
"""
for event in knock_room_state:
event_type = event["type"]
# Check that this event type is one of those that we expected.
# Note: This will also check that no excess state was included
self.assertIn(event_type, expected_room_state)
# Check the state content matches
self.assertEquals(
expected_room_state[event_type]["content"], event["content"]
)
# Check the state key is correct
self.assertEqual(
expected_room_state[event_type]["state_key"], event["state_key"]
)
# Ensure the event has been stripped
self.assertNotIn("signatures", event)
# Pop once we've found and processed a state event
expected_room_state.pop(event_type)
# Check that all expected state events were accounted for
self.assertEqual(len(expected_room_state), 0)
class FederationKnockingTestCase(
FederatingHomeserverTestCase, KnockingStrippedStateEventHelperMixin
):
servlets = [
admin.register_servlets,
room.register_servlets,
login.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastore()
# We're not going to be properly signing events as our remote homeserver is fake,
# therefore disable event signature checks.
# Note that these checks are not relevant to this test case.
# Have this homeserver auto-approve all event signature checking.
async def approve_all_signature_checking(_, pdu):
return pdu
homeserver.get_federation_server()._check_sigs_and_hash = (
approve_all_signature_checking
)
# Have this homeserver skip event auth checks. This is necessary due to
# event auth checks ensuring that events were signed by the sender's homeserver.
async def _check_event_auth(
origin, event, context, state, auth_events, backfilled
):
return context
homeserver.get_federation_handler()._check_event_auth = _check_event_auth
return super().prepare(reactor, clock, homeserver)
@override_config({"experimental_features": {"msc2403_enabled": True}})
def test_room_state_returned_when_knocking(self):
"""
Tests that specific, stripped state events from a room are returned after
a remote homeserver successfully knocks on a local room.
"""
user_id = self.register_user("u1", "you the one")
user_token = self.login("u1", "you the one")
fake_knocking_user_id = "@user:other.example.com"
# Create a room with a room version that includes knocking
room_id = self.helper.create_room_as(
"u1",
is_public=False,
room_version=RoomVersions.MSC2403.identifier,
tok=user_token,
)
# Update the join rules and add additional state to the room to check for later
expected_room_state = self.send_example_state_events_to_room(
self.hs, room_id, user_id
)
channel = self.make_request(
"GET",
"/_matrix/federation/unstable/%s/make_knock/%s/%s?ver=%s"
% (
KNOCK_UNSTABLE_IDENTIFIER,
room_id,
fake_knocking_user_id,
# Inform the remote that we support the room version of the room we're
# knocking on
RoomVersions.MSC2403.identifier,
),
)
self.assertEquals(200, channel.code, channel.result)
# Note: We don't expect the knock membership event to be sent over federation as
# part of the stripped room state, as the knocking homeserver already has that
# event. It is only done for clients during /sync
# Extract the generated knock event json
knock_event = channel.json_body["event"]
# Check that the event has things we expect in it
self.assertEquals(knock_event["room_id"], room_id)
self.assertEquals(knock_event["sender"], fake_knocking_user_id)
self.assertEquals(knock_event["state_key"], fake_knocking_user_id)
self.assertEquals(knock_event["type"], EventTypes.Member)
self.assertEquals(knock_event["content"]["membership"], Membership.KNOCK)
# Turn the event json dict into a proper event.
# We won't sign it properly, but that's OK as we stub out event auth in `prepare`
signed_knock_event = builder.create_local_event_from_event_dict(
self.clock,
self.hs.hostname,
self.hs.signing_key,
room_version=RoomVersions.MSC2403,
event_dict=knock_event,
)
# Convert our proper event back to json dict format
signed_knock_event_json = signed_knock_event.get_pdu_json(
self.clock.time_msec()
)
# Send the signed knock event into the room
channel = self.make_request(
"PUT",
"/_matrix/federation/unstable/%s/send_knock/%s/%s"
% (KNOCK_UNSTABLE_IDENTIFIER, room_id, signed_knock_event.event_id),
signed_knock_event_json,
)
self.assertEquals(200, channel.code, channel.result)
# Check that we got the stripped room state in return
room_state_events = channel.json_body["knock_state_events"]
# Validate the stripped room state events
self.check_knock_room_state_against_room_state(
room_state_events, expected_room_state
)

View file

@ -17,10 +17,14 @@ import json
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.rest.client.v1 import login, room from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import read_marker, sync from synapse.rest.client.v2_alpha import knock, read_marker, sync
from tests import unittest from tests import unittest
from tests.federation.transport.test_knocking import (
KnockingStrippedStateEventHelperMixin,
)
from tests.server import TimedOutException from tests.server import TimedOutException
from tests.unittest import override_config
class FilterTestCase(unittest.HomeserverTestCase): class FilterTestCase(unittest.HomeserverTestCase):
@ -305,6 +309,93 @@ class SyncTypingTests(unittest.HomeserverTestCase):
self.make_request("GET", sync_url % (access_token, next_batch)) self.make_request("GET", sync_url % (access_token, next_batch))
class SyncKnockTestCase(
unittest.HomeserverTestCase, KnockingStrippedStateEventHelperMixin
):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
sync.register_servlets,
knock.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.url = "/sync?since=%s"
self.next_batch = "s0"
# Register the first user (used to create the room to knock on).
self.user_id = self.register_user("kermit", "monkey")
self.tok = self.login("kermit", "monkey")
# Create the room we'll knock on.
self.room_id = self.helper.create_room_as(
self.user_id,
is_public=False,
room_version="xyz.amorgan.knock",
tok=self.tok,
)
# Register the second user (used to knock on the room).
self.knocker = self.register_user("knocker", "monkey")
self.knocker_tok = self.login("knocker", "monkey")
# Perform an initial sync for the knocking user.
channel = self.make_request(
"GET",
self.url % self.next_batch,
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Store the next batch for the next request.
self.next_batch = channel.json_body["next_batch"]
# Set up some room state to test with.
self.expected_room_state = self.send_example_state_events_to_room(
hs, self.room_id, self.user_id
)
@override_config({"experimental_features": {"msc2403_enabled": True}})
def test_knock_room_state(self):
"""Tests that /sync returns state from a room after knocking on it."""
# Knock on a room
channel = self.make_request(
"POST",
"/_matrix/client/unstable/xyz.amorgan.knock/%s" % (self.room_id,),
b"{}",
self.knocker_tok,
)
self.assertEquals(200, channel.code, channel.result)
# We expect to see the knock event in the stripped room state later
self.expected_room_state[EventTypes.Member] = {
"content": {"membership": "xyz.amorgan.knock", "displayname": "knocker"},
"state_key": "@knocker:test",
}
# Check that /sync includes stripped state from the room
channel = self.make_request(
"GET",
self.url % self.next_batch,
access_token=self.knocker_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Extract the stripped room state events from /sync
knock_entry = channel.json_body["rooms"]["xyz.amorgan.knock"]
room_state_events = knock_entry[self.room_id]["knock_state"]["events"]
# Validate that the knock membership event came last
self.assertEqual(room_state_events[-1]["type"], EventTypes.Member)
# Validate the stripped room state events
self.check_knock_room_state_against_room_state(
room_state_events, self.expected_room_state
)
class UnreadMessagesTestCase(unittest.HomeserverTestCase): class UnreadMessagesTestCase(unittest.HomeserverTestCase):
servlets = [ servlets = [
synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets,
@ -447,7 +538,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
) )
self._check_unread_count(5) self._check_unread_count(5)
def _check_unread_count(self, expected_count: True): def _check_unread_count(self, expected_count: int):
"""Syncs and compares the unread count with the expected value.""" """Syncs and compares the unread count with the expected value."""
channel = self.make_request( channel = self.make_request(