Merge pull request #6840 from matrix-org/rav/federation_client_async

Port much of `synapse.federation.federation_client` to async/await
This commit is contained in:
Richard van der Hoff 2020-02-05 16:56:39 +00:00 committed by GitHub
commit 577f460369
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 143 additions and 136 deletions

1
changelog.d/6840.misc Normal file
View file

@ -0,0 +1 @@
Port much of `synapse.handlers.federation` to async/await.

View file

@ -17,7 +17,18 @@
import copy import copy
import itertools import itertools
import logging import logging
from typing import Dict, Iterable from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
TypeVar,
)
from prometheus_client import Counter from prometheus_client import Counter
@ -35,12 +46,14 @@ from synapse.api.errors import (
from synapse.api.room_versions import ( from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS, KNOWN_ROOM_VERSIONS,
EventFormatVersions, EventFormatVersions,
RoomVersion,
RoomVersions, RoomVersions,
) )
from synapse.events import builder, room_version_to_event_format from synapse.events import EventBase, builder, room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.types import JsonDict
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
@ -52,6 +65,8 @@ sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["t
PDU_RETRY_TIME_MS = 1 * 60 * 1000 PDU_RETRY_TIME_MS = 1 * 60 * 1000
T = TypeVar("T")
class InvalidResponseError(RuntimeError): class InvalidResponseError(RuntimeError):
"""Helper for _try_destination_list: indicates that the server returned a response """Helper for _try_destination_list: indicates that the server returned a response
@ -170,21 +185,17 @@ class FederationClient(FederationBase):
sent_queries_counter.labels("client_one_time_keys").inc() sent_queries_counter.labels("client_one_time_keys").inc()
return self.transport_layer.claim_client_keys(destination, content, timeout) return self.transport_layer.claim_client_keys(destination, content, timeout)
@defer.inlineCallbacks async def backfill(
@log_function self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
def backfill(self, dest, room_id, limit, extremities): ) -> List[EventBase]:
"""Requests some more historic PDUs for the given context from the """Requests some more historic PDUs for the given room from the
given destination server. given destination server.
Args: Args:
dest (str): The remote homeserver to ask. dest (str): The remote homeserver to ask.
room_id (str): The room_id to backfill. room_id (str): The room_id to backfill.
limit (int): The maximum number of PDUs to return. limit (int): The maximum number of events to return.
extremities (list): List of PDU id and origins of the first pdus extremities (list): our current backwards extremities, to backfill from
we have seen from the context
Returns:
Deferred: Results in the received PDUs.
""" """
logger.debug("backfill extrem=%s", extremities) logger.debug("backfill extrem=%s", extremities)
@ -192,13 +203,13 @@ class FederationClient(FederationBase):
if not extremities: if not extremities:
return return
transaction_data = yield self.transport_layer.backfill( transaction_data = await self.transport_layer.backfill(
dest, room_id, extremities, limit dest, room_id, extremities, limit
) )
logger.debug("backfill transaction_data=%r", transaction_data) logger.debug("backfill transaction_data=%r", transaction_data)
room_version = yield self.store.get_room_version_id(room_id) room_version = await self.store.get_room_version_id(room_id)
format_ver = room_version_to_event_format(room_version) format_ver = room_version_to_event_format(room_version)
pdus = [ pdus = [
@ -207,7 +218,7 @@ class FederationClient(FederationBase):
] ]
# FIXME: We should handle signature failures more gracefully. # FIXME: We should handle signature failures more gracefully.
pdus[:] = yield make_deferred_yieldable( pdus[:] = await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
@ -215,11 +226,14 @@ class FederationClient(FederationBase):
return pdus return pdus
@defer.inlineCallbacks async def get_pdu(
@log_function self,
def get_pdu( destinations: Iterable[str],
self, destinations, event_id, room_version, outlier=False, timeout=None event_id: str,
): room_version: str,
outlier: bool = False,
timeout: Optional[int] = None,
) -> Optional[EventBase]:
"""Requests the PDU with given origin and ID from the remote home """Requests the PDU with given origin and ID from the remote home
servers. servers.
@ -227,18 +241,17 @@ class FederationClient(FederationBase):
one succeeds. one succeeds.
Args: Args:
destinations (list): Which homeservers to query destinations: Which homeservers to query
event_id (str): event to fetch event_id: event to fetch
room_version (str): version of the room room_version: version of the room
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if outlier: Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitrary point in the context as opposed to part it's from an arbitrary point in the context as opposed to part
of the current block of PDUs. Defaults to `False` of the current block of PDUs. Defaults to `False`
timeout (int): How long to try (in ms) each destination for before timeout: How long to try (in ms) each destination for before
moving to the next destination. None indicates no timeout. moving to the next destination. None indicates no timeout.
Returns: Returns:
Deferred: Results in the requested PDU, or None if we were unable to find The requested PDU, or None if we were unable to find it.
it.
""" """
# TODO: Rate limit the number of times we try and get the same event. # TODO: Rate limit the number of times we try and get the same event.
@ -259,7 +272,7 @@ class FederationClient(FederationBase):
continue continue
try: try:
transaction_data = yield self.transport_layer.get_event( transaction_data = await self.transport_layer.get_event(
destination, event_id, timeout=timeout destination, event_id, timeout=timeout
) )
@ -279,7 +292,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0] pdu = pdu_list[0]
# Check signatures are correct. # Check signatures are correct.
signed_pdu = yield self._check_sigs_and_hash(room_version, pdu) signed_pdu = await self._check_sigs_and_hash(room_version, pdu)
break break
@ -309,15 +322,16 @@ class FederationClient(FederationBase):
return signed_pdu return signed_pdu
@defer.inlineCallbacks async def get_room_state_ids(
def get_room_state_ids(self, destination: str, room_id: str, event_id: str): self, destination: str, room_id: str, event_id: str
) -> Tuple[List[str], List[str]]:
"""Calls the /state_ids endpoint to fetch the state at a particular point """Calls the /state_ids endpoint to fetch the state at a particular point
in the room, and the auth events for the given event in the room, and the auth events for the given event
Returns: Returns:
Tuple[List[str], List[str]]: a tuple of (state event_ids, auth event_ids) a tuple of (state event_ids, auth event_ids)
""" """
result = yield self.transport_layer.get_room_state_ids( result = await self.transport_layer.get_room_state_ids(
destination, room_id, event_id=event_id destination, room_id, event_id=event_id
) )
@ -331,19 +345,17 @@ class FederationClient(FederationBase):
return state_event_ids, auth_event_ids return state_event_ids, auth_event_ids
@defer.inlineCallbacks async def get_event_auth(self, destination, room_id, event_id):
@log_function res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
def get_event_auth(self, destination, room_id, event_id):
res = yield self.transport_layer.get_event_auth(destination, room_id, event_id)
room_version = yield self.store.get_room_version_id(room_id) room_version = await self.store.get_room_version_id(room_id)
format_ver = room_version_to_event_format(room_version) format_ver = room_version_to_event_format(room_version)
auth_chain = [ auth_chain = [
event_from_pdu_json(p, format_ver, outlier=True) for p in res["auth_chain"] event_from_pdu_json(p, format_ver, outlier=True) for p in res["auth_chain"]
] ]
signed_auth = yield self._check_sigs_and_hash_and_fetch( signed_auth = await self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True, room_version=room_version destination, auth_chain, outlier=True, room_version=room_version
) )
@ -351,17 +363,21 @@ class FederationClient(FederationBase):
return signed_auth return signed_auth
@defer.inlineCallbacks async def _try_destination_list(
def _try_destination_list(self, description, destinations, callback): self,
description: str,
destinations: Iterable[str],
callback: Callable[[str], Awaitable[T]],
) -> T:
"""Try an operation on a series of servers, until it succeeds """Try an operation on a series of servers, until it succeeds
Args: Args:
description (unicode): description of the operation we're doing, for logging description: description of the operation we're doing, for logging
destinations (Iterable[unicode]): list of server_names to try destinations: list of server_names to try
callback (callable): Function to run for each server. Passed a single callback: Function to run for each server. Passed a single
argument: the server_name to try. May return a deferred. argument: the server_name to try.
If the callback raises a CodeMessageException with a 300/400 code, If the callback raises a CodeMessageException with a 300/400 code,
attempts to perform the operation stop immediately and the exception is attempts to perform the operation stop immediately and the exception is
@ -372,7 +388,7 @@ class FederationClient(FederationBase):
suppressed if the exception is an InvalidResponseError. suppressed if the exception is an InvalidResponseError.
Returns: Returns:
The [Deferred] result of callback, if it succeeds The result of callback, if it succeeds
Raises: Raises:
SynapseError if the chosen remote server returns a 300/400 code, or SynapseError if the chosen remote server returns a 300/400 code, or
@ -383,7 +399,7 @@ class FederationClient(FederationBase):
continue continue
try: try:
res = yield callback(destination) res = await callback(destination)
return res return res
except InvalidResponseError as e: except InvalidResponseError as e:
logger.warning("Failed to %s via %s: %s", description, destination, e) logger.warning("Failed to %s via %s: %s", description, destination, e)
@ -402,12 +418,12 @@ class FederationClient(FederationBase):
) )
except Exception: except Exception:
logger.warning( logger.warning(
"Failed to %s via %s", description, destination, exc_info=1 "Failed to %s via %s", description, destination, exc_info=True
) )
raise SynapseError(502, "Failed to %s via any server" % (description,)) raise SynapseError(502, "Failed to %s via any server" % (description,))
def make_membership_event( async def make_membership_event(
self, self,
destinations: Iterable[str], destinations: Iterable[str],
room_id: str, room_id: str,
@ -415,7 +431,7 @@ class FederationClient(FederationBase):
membership: str, membership: str,
content: dict, content: dict,
params: Dict[str, str], params: Dict[str, str],
): ) -> Tuple[str, EventBase, RoomVersion]:
""" """
Creates an m.room.member event, with context, without participating in the room. Creates an m.room.member event, with context, without participating in the room.
@ -436,19 +452,19 @@ class FederationClient(FederationBase):
content: Any additional data to put into the content field of the content: Any additional data to put into the content field of the
event. event.
params: Query parameters to include in the request. params: Query parameters to include in the request.
Return:
Deferred[Tuple[str, FrozenEvent, RoomVersion]]: resolves to a tuple of Returns:
`(origin, event, room_version)` where origin is the remote `(origin, event, room_version)` where origin is the remote
homeserver which generated the event, and room_version is the homeserver which generated the event, and room_version is the
version of the room. version of the room.
Fails with a `UnsupportedRoomVersionError` if remote responds with Raises:
a room version we don't understand. UnsupportedRoomVersionError: if remote responds with
a room version we don't understand.
Fails with a ``SynapseError`` if the chosen remote server SynapseError: if the chosen remote server returns a 300/400 code.
returns a 300/400 code.
Fails with a ``RuntimeError`` if no servers were reachable. RuntimeError: if no servers were reachable.
""" """
valid_memberships = {Membership.JOIN, Membership.LEAVE} valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships: if membership not in valid_memberships:
@ -457,9 +473,8 @@ class FederationClient(FederationBase):
% (membership, ",".join(valid_memberships)) % (membership, ",".join(valid_memberships))
) )
@defer.inlineCallbacks async def send_request(destination: str) -> Tuple[str, EventBase, RoomVersion]:
def send_request(destination): ret = await self.transport_layer.make_membership_event(
ret = yield self.transport_layer.make_membership_event(
destination, room_id, user_id, membership, params destination, room_id, user_id, membership, params
) )
@ -492,33 +507,35 @@ class FederationClient(FederationBase):
event_dict=pdu_dict, event_dict=pdu_dict,
) )
return (destination, ev, room_version) return destination, ev, room_version
return self._try_destination_list( return await self._try_destination_list(
"make_" + membership, destinations, send_request "make_" + membership, destinations, send_request
) )
def send_join(self, destinations, pdu, event_format_version): async def send_join(
self, destinations: Iterable[str], pdu: EventBase, event_format_version: int
) -> Dict[str, Any]:
"""Sends a join event to one of a list of homeservers. """Sends a join event to one of a list of homeservers.
Doing so will cause the remote server to add the event to the graph, Doing so will cause the remote server to add the event to the graph,
and send the event out to the rest of the federation. and send the event out to the rest of the federation.
Args: Args:
destinations (str): Candidate homeservers which are probably destinations: Candidate homeservers which are probably
participating in the room. participating in the room.
pdu (BaseEvent): event to be sent pdu: event to be sent
event_format_version (int): The event format version event_format_version: The event format version
Return: Returns:
Deferred: resolves to a dict with members ``origin`` (a string a dict with members ``origin`` (a string
giving the serer the event was sent to, ``state`` (?) and giving the server the event was sent to), ``state`` (?) and
``auth_chain``. ``auth_chain``.
Fails with a ``SynapseError`` if the chosen remote server Raises:
returns a 300/400 code. SynapseError: if the chosen remote server returns a 300/400 code.
Fails with a ``RuntimeError`` if no servers were reachable. RuntimeError: if no servers were reachable.
""" """
def check_authchain_validity(signed_auth_chain): def check_authchain_validity(signed_auth_chain):
@ -538,9 +555,8 @@ class FederationClient(FederationBase):
"room appears to have unsupported version %s" % (room_version,) "room appears to have unsupported version %s" % (room_version,)
) )
@defer.inlineCallbacks async def send_request(destination) -> Dict[str, Any]:
def send_request(destination): content = await self._do_send_join(destination, pdu)
content = yield self._do_send_join(destination, pdu)
logger.debug("Got content: %s", content) logger.debug("Got content: %s", content)
@ -569,7 +585,7 @@ class FederationClient(FederationBase):
# invalid, and it would fail auth checks anyway. # invalid, and it would fail auth checks anyway.
raise SynapseError(400, "No create event in state") raise SynapseError(400, "No create event in state")
valid_pdus = yield self._check_sigs_and_hash_and_fetch( valid_pdus = await self._check_sigs_and_hash_and_fetch(
destination, destination,
list(pdus.values()), list(pdus.values()),
outlier=True, outlier=True,
@ -605,14 +621,13 @@ class FederationClient(FederationBase):
"origin": destination, "origin": destination,
} }
return self._try_destination_list("send_join", destinations, send_request) return await self._try_destination_list("send_join", destinations, send_request)
@defer.inlineCallbacks async def _do_send_join(self, destination: str, pdu: EventBase):
def _do_send_join(self, destination, pdu):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
try: try:
content = yield self.transport_layer.send_join_v2( content = await self.transport_layer.send_join_v2(
destination=destination, destination=destination,
room_id=pdu.room_id, room_id=pdu.room_id,
event_id=pdu.event_id, event_id=pdu.event_id,
@ -634,7 +649,7 @@ class FederationClient(FederationBase):
logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API") logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
resp = yield self.transport_layer.send_join_v1( resp = await self.transport_layer.send_join_v1(
destination=destination, destination=destination,
room_id=pdu.room_id, room_id=pdu.room_id,
event_id=pdu.event_id, event_id=pdu.event_id,
@ -645,45 +660,42 @@ class FederationClient(FederationBase):
# content. # content.
return resp[1] return resp[1]
@defer.inlineCallbacks async def send_invite(
def send_invite(self, destination, room_id, event_id, pdu): self, destination: str, room_id: str, event_id: str, pdu: EventBase,
room_version = yield self.store.get_room_version_id(room_id) ) -> EventBase:
room_version = await self.store.get_room_version_id(room_id)
content = yield self._do_send_invite(destination, pdu, room_version) content = await self._do_send_invite(destination, pdu, room_version)
pdu_dict = content["event"] pdu_dict = content["event"]
logger.debug("Got response to send_invite: %s", pdu_dict) logger.debug("Got response to send_invite: %s", pdu_dict)
room_version = yield self.store.get_room_version_id(room_id) room_version = await self.store.get_room_version_id(room_id)
format_ver = room_version_to_event_format(room_version) format_ver = room_version_to_event_format(room_version)
pdu = event_from_pdu_json(pdu_dict, format_ver) pdu = event_from_pdu_json(pdu_dict, format_ver)
# Check signatures are correct. # Check signatures are correct.
pdu = yield self._check_sigs_and_hash(room_version, pdu) pdu = await self._check_sigs_and_hash(room_version, pdu)
# FIXME: We should handle signature failures more gracefully. # FIXME: We should handle signature failures more gracefully.
return pdu return pdu
@defer.inlineCallbacks async def _do_send_invite(
def _do_send_invite(self, destination, pdu, room_version): self, destination: str, pdu: EventBase, room_version: str
) -> JsonDict:
"""Actually sends the invite, first trying v2 API and falling back to """Actually sends the invite, first trying v2 API and falling back to
v1 API if necessary. v1 API if necessary.
Args:
destination (str): Target server
pdu (FrozenEvent)
room_version (str)
Returns: Returns:
dict: The event as a dict as returned by the remote server The event as a dict as returned by the remote server
""" """
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
try: try:
content = yield self.transport_layer.send_invite_v2( content = await self.transport_layer.send_invite_v2(
destination=destination, destination=destination,
room_id=pdu.room_id, room_id=pdu.room_id,
event_id=pdu.event_id, event_id=pdu.event_id,
@ -722,7 +734,7 @@ class FederationClient(FederationBase):
# Didn't work, try v1 API. # Didn't work, try v1 API.
# Note the v1 API returns a tuple of `(200, content)` # Note the v1 API returns a tuple of `(200, content)`
_, content = yield self.transport_layer.send_invite_v1( _, content = await self.transport_layer.send_invite_v1(
destination=destination, destination=destination,
room_id=pdu.room_id, room_id=pdu.room_id,
event_id=pdu.event_id, event_id=pdu.event_id,
@ -730,7 +742,7 @@ class FederationClient(FederationBase):
) )
return content return content
def send_leave(self, destinations, pdu): async def send_leave(self, destinations: Iterable[str], pdu: EventBase) -> None:
"""Sends a leave event to one of a list of homeservers. """Sends a leave event to one of a list of homeservers.
Doing so will cause the remote server to add the event to the graph, Doing so will cause the remote server to add the event to the graph,
@ -739,34 +751,29 @@ class FederationClient(FederationBase):
This is mostly useful to reject received invites. This is mostly useful to reject received invites.
Args: Args:
destinations (str): Candidate homeservers which are probably destinations: Candidate homeservers which are probably
participating in the room. participating in the room.
pdu (BaseEvent): event to be sent pdu: event to be sent
Return: Raises:
Deferred: resolves to None. SynapseError if the chosen remote server returns a 300/400 code.
Fails with a ``SynapseError`` if the chosen remote server RuntimeError if no servers were reachable.
returns a 300/400 code.
Fails with a ``RuntimeError`` if no servers were reachable.
""" """
@defer.inlineCallbacks async def send_request(destination: str) -> None:
def send_request(destination): content = await self._do_send_leave(destination, pdu)
content = yield self._do_send_leave(destination, pdu)
logger.debug("Got content: %s", content) logger.debug("Got content: %s", content)
return None
return self._try_destination_list("send_leave", destinations, send_request) return await self._try_destination_list(
"send_leave", destinations, send_request
)
@defer.inlineCallbacks async def _do_send_leave(self, destination, pdu):
def _do_send_leave(self, destination, pdu):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
try: try:
content = yield self.transport_layer.send_leave_v2( content = await self.transport_layer.send_leave_v2(
destination=destination, destination=destination,
room_id=pdu.room_id, room_id=pdu.room_id,
event_id=pdu.event_id, event_id=pdu.event_id,
@ -788,7 +795,7 @@ class FederationClient(FederationBase):
logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API") logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API")
resp = yield self.transport_layer.send_leave_v1( resp = await self.transport_layer.send_leave_v1(
destination=destination, destination=destination,
room_id=pdu.room_id, room_id=pdu.room_id,
event_id=pdu.event_id, event_id=pdu.event_id,
@ -820,34 +827,33 @@ class FederationClient(FederationBase):
third_party_instance_id=third_party_instance_id, third_party_instance_id=third_party_instance_id,
) )
@defer.inlineCallbacks async def get_missing_events(
def get_missing_events(
self, self,
destination, destination: str,
room_id, room_id: str,
earliest_events_ids, earliest_events_ids: Sequence[str],
latest_events, latest_events: Iterable[EventBase],
limit, limit: int,
min_depth, min_depth: int,
timeout, timeout: int,
): ) -> List[EventBase]:
"""Tries to fetch events we are missing. This is called when we receive """Tries to fetch events we are missing. This is called when we receive
an event without having received all of its ancestors. an event without having received all of its ancestors.
Args: Args:
destination (str) destination
room_id (str) room_id
earliest_events_ids (list): List of event ids. Effectively the earliest_events_ids: List of event ids. Effectively the
events we expected to receive, but haven't. `get_missing_events` events we expected to receive, but haven't. `get_missing_events`
should only return events that didn't happen before these. should only return events that didn't happen before these.
latest_events (list): List of events we have received that we don't latest_events: List of events we have received that we don't
have all previous events for. have all previous events for.
limit (int): Maximum number of events to return. limit: Maximum number of events to return.
min_depth (int): Minimum depth of events tor return. min_depth: Minimum depth of events to return.
timeout (int): Max time to wait in ms timeout: Max time to wait in ms
""" """
try: try:
content = yield self.transport_layer.get_missing_events( content = await self.transport_layer.get_missing_events(
destination=destination, destination=destination,
room_id=room_id, room_id=room_id,
earliest_events=earliest_events_ids, earliest_events=earliest_events_ids,
@ -857,14 +863,14 @@ class FederationClient(FederationBase):
timeout=timeout, timeout=timeout,
) )
room_version = yield self.store.get_room_version_id(room_id) room_version = await self.store.get_room_version_id(room_id)
format_ver = room_version_to_event_format(room_version) format_ver = room_version_to_event_format(room_version)
events = [ events = [
event_from_pdu_json(e, format_ver) for e in content.get("events", []) event_from_pdu_json(e, format_ver) for e in content.get("events", [])
] ]
signed_events = yield self._check_sigs_and_hash_and_fetch( signed_events = await self._check_sigs_and_hash_and_fetch(
destination, events, outlier=False, room_version=room_version destination, events, outlier=False, room_version=room_version
) )
except HttpResponseException as e: except HttpResponseException as e: