mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-21 17:15:38 +03:00
Finish type hints for federation client HTTP code. (#15465)
This commit is contained in:
parent
19141b9432
commit
ea5c3ede4f
7 changed files with 82 additions and 42 deletions
1
changelog.d/15465.misc
Normal file
1
changelog.d/15465.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Improve type hints.
|
6
mypy.ini
6
mypy.ini
|
@ -33,12 +33,6 @@ exclude = (?x)
|
|||
|synapse/storage/schema/
|
||||
)$
|
||||
|
||||
[mypy-synapse.federation.transport.client]
|
||||
disallow_untyped_defs = False
|
||||
|
||||
[mypy-synapse.http.matrixfederationclient]
|
||||
disallow_untyped_defs = False
|
||||
|
||||
[mypy-synapse.metrics._reactor_metrics]
|
||||
disallow_untyped_defs = False
|
||||
# This module imports select.epoll. That exists on Linux, but doesn't on macOS.
|
||||
|
|
|
@ -280,15 +280,11 @@ class FederationClient(FederationBase):
|
|||
logger.debug("backfill transaction_data=%r", transaction_data)
|
||||
|
||||
if not isinstance(transaction_data, dict):
|
||||
# TODO we probably want an exception type specific to federation
|
||||
# client validation.
|
||||
raise TypeError("Backfill transaction_data is not a dict.")
|
||||
raise InvalidResponseError("Backfill transaction_data is not a dict.")
|
||||
|
||||
transaction_data_pdus = transaction_data.get("pdus")
|
||||
if not isinstance(transaction_data_pdus, list):
|
||||
# TODO we probably want an exception type specific to federation
|
||||
# client validation.
|
||||
raise TypeError("transaction_data.pdus is not a list.")
|
||||
raise InvalidResponseError("transaction_data.pdus is not a list.")
|
||||
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import logging
|
||||
import urllib
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Collection,
|
||||
|
@ -42,18 +43,21 @@ from synapse.api.urls import (
|
|||
)
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.federation.units import Transaction
|
||||
from synapse.http.matrixfederationclient import ByteParser
|
||||
from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
|
||||
from synapse.http.types import QueryParams
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import ExceptionBundle
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransportLayerClient:
|
||||
"""Sends federation HTTP requests to other servers"""
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.server_name = hs.hostname
|
||||
self.client = hs.get_federation_http_client()
|
||||
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
|
||||
|
@ -133,7 +137,7 @@ class TransportLayerClient:
|
|||
|
||||
async def backfill(
|
||||
self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
|
||||
) -> Optional[JsonDict]:
|
||||
) -> Optional[Union[JsonDict, list]]:
|
||||
"""Requests `limit` previous PDUs in a given context before list of
|
||||
PDUs.
|
||||
|
||||
|
@ -388,6 +392,7 @@ class TransportLayerClient:
|
|||
# server was just having a momentary blip, the room will be out of
|
||||
# sync.
|
||||
ignore_backoff=True,
|
||||
parser=LegacyJsonSendParser(),
|
||||
)
|
||||
|
||||
async def send_leave_v2(
|
||||
|
@ -445,7 +450,11 @@ class TransportLayerClient:
|
|||
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
|
||||
|
||||
return await self.client.put_json(
|
||||
destination=destination, path=path, data=content, ignore_backoff=True
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
parser=LegacyJsonSendParser(),
|
||||
)
|
||||
|
||||
async def send_invite_v2(
|
||||
|
|
|
@ -17,7 +17,6 @@ import codecs
|
|||
import logging
|
||||
import random
|
||||
import sys
|
||||
import typing
|
||||
import urllib.parse
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO, StringIO
|
||||
|
@ -30,9 +29,11 @@ from typing import (
|
|||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
TextIO,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
|
@ -183,20 +184,61 @@ class MatrixFederationRequest:
|
|||
return self.json
|
||||
|
||||
|
||||
class JsonParser(ByteParser[Union[JsonDict, list]]):
|
||||
class _BaseJsonParser(ByteParser[T]):
|
||||
"""A parser that buffers the response and tries to parse it as JSON."""
|
||||
|
||||
CONTENT_TYPE = "application/json"
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(
|
||||
self, validator: Optional[Callable[[Optional[object]], bool]] = None
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
validator: A callable which takes the parsed JSON value and returns
|
||||
true if the value is valid.
|
||||
"""
|
||||
self._buffer = StringIO()
|
||||
self._binary_wrapper = BinaryIOWrapper(self._buffer)
|
||||
self._validator = validator
|
||||
|
||||
def write(self, data: bytes) -> int:
|
||||
return self._binary_wrapper.write(data)
|
||||
|
||||
def finish(self) -> Union[JsonDict, list]:
|
||||
return json_decoder.decode(self._buffer.getvalue())
|
||||
def finish(self) -> T:
|
||||
result = json_decoder.decode(self._buffer.getvalue())
|
||||
if self._validator is not None and not self._validator(result):
|
||||
raise ValueError(
|
||||
f"Received incorrect JSON value: {result.__class__.__name__}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class JsonParser(_BaseJsonParser[JsonDict]):
|
||||
"""A parser that buffers the response and tries to parse it as a JSON object."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(self._validate)
|
||||
|
||||
@staticmethod
|
||||
def _validate(v: Any) -> bool:
|
||||
return isinstance(v, dict)
|
||||
|
||||
|
||||
class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]):
|
||||
"""Ensure the legacy responses of /send_join & /send_leave are correct."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(self._validate)
|
||||
|
||||
@staticmethod
|
||||
def _validate(v: Any) -> bool:
|
||||
# Match [integer, JSON dict]
|
||||
return (
|
||||
isinstance(v, list)
|
||||
and len(v) == 2
|
||||
and type(v[0]) == int
|
||||
and isinstance(v[1], dict)
|
||||
)
|
||||
|
||||
|
||||
async def _handle_response(
|
||||
|
@ -313,9 +355,7 @@ async def _handle_response(
|
|||
class BinaryIOWrapper:
|
||||
"""A wrapper for a TextIO which converts from bytes on the fly."""
|
||||
|
||||
def __init__(
|
||||
self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
|
||||
):
|
||||
def __init__(self, file: TextIO, encoding: str = "utf-8", errors: str = "strict"):
|
||||
self.decoder = codecs.getincrementaldecoder(encoding)(errors)
|
||||
self.file = file
|
||||
|
||||
|
@ -793,7 +833,7 @@ class MatrixFederationHttpClient:
|
|||
backoff_on_404: bool = False,
|
||||
try_trailing_slash_on_400: bool = False,
|
||||
parser: Literal[None] = None,
|
||||
) -> Union[JsonDict, list]:
|
||||
) -> JsonDict:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
@ -825,8 +865,8 @@ class MatrixFederationHttpClient:
|
|||
ignore_backoff: bool = False,
|
||||
backoff_on_404: bool = False,
|
||||
try_trailing_slash_on_400: bool = False,
|
||||
parser: Optional[ByteParser] = None,
|
||||
):
|
||||
parser: Optional[ByteParser[T]] = None,
|
||||
) -> Union[JsonDict, T]:
|
||||
"""Sends the specified json data using PUT
|
||||
|
||||
Args:
|
||||
|
@ -902,7 +942,7 @@ class MatrixFederationHttpClient:
|
|||
_sec_timeout = self.default_timeout
|
||||
|
||||
if parser is None:
|
||||
parser = JsonParser()
|
||||
parser = cast(ByteParser[T], JsonParser())
|
||||
|
||||
body = await _handle_response(
|
||||
self.reactor,
|
||||
|
@ -924,7 +964,7 @@ class MatrixFederationHttpClient:
|
|||
timeout: Optional[int] = None,
|
||||
ignore_backoff: bool = False,
|
||||
args: Optional[QueryParams] = None,
|
||||
) -> Union[JsonDict, list]:
|
||||
) -> JsonDict:
|
||||
"""Sends the specified json data using POST
|
||||
|
||||
Args:
|
||||
|
@ -998,7 +1038,7 @@ class MatrixFederationHttpClient:
|
|||
ignore_backoff: bool = False,
|
||||
try_trailing_slash_on_400: bool = False,
|
||||
parser: Literal[None] = None,
|
||||
) -> Union[JsonDict, list]:
|
||||
) -> JsonDict:
|
||||
...
|
||||
|
||||
@overload
|
||||
|
@ -1024,8 +1064,8 @@ class MatrixFederationHttpClient:
|
|||
timeout: Optional[int] = None,
|
||||
ignore_backoff: bool = False,
|
||||
try_trailing_slash_on_400: bool = False,
|
||||
parser: Optional[ByteParser] = None,
|
||||
):
|
||||
parser: Optional[ByteParser[T]] = None,
|
||||
) -> Union[JsonDict, T]:
|
||||
"""GETs some json from the given host homeserver and path
|
||||
|
||||
Args:
|
||||
|
@ -1091,7 +1131,7 @@ class MatrixFederationHttpClient:
|
|||
_sec_timeout = self.default_timeout
|
||||
|
||||
if parser is None:
|
||||
parser = JsonParser()
|
||||
parser = cast(ByteParser[T], JsonParser())
|
||||
|
||||
body = await _handle_response(
|
||||
self.reactor,
|
||||
|
@ -1112,7 +1152,7 @@ class MatrixFederationHttpClient:
|
|||
timeout: Optional[int] = None,
|
||||
ignore_backoff: bool = False,
|
||||
args: Optional[QueryParams] = None,
|
||||
) -> Union[JsonDict, list]:
|
||||
) -> JsonDict:
|
||||
"""Send a DELETE request to the remote expecting some json response
|
||||
|
||||
Args:
|
||||
|
|
|
@ -75,7 +75,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
|||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
|
||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||
return_value=make_awaitable(("", 1))
|
||||
)
|
||||
|
@ -106,7 +106,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
|||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
|
||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||
return_value=make_awaitable(("", 1))
|
||||
)
|
||||
|
@ -143,7 +143,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
|||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
|
||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||
return_value=make_awaitable(("", 1))
|
||||
)
|
||||
|
@ -200,7 +200,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
|
|||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
|
||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||
return_value=make_awaitable(("", 1))
|
||||
)
|
||||
|
@ -230,7 +230,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
|
|||
fed_transport = self.hs.get_federation_transport_client()
|
||||
|
||||
# Mock out some things, because we don't want to test the whole join
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
|
||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||
return_value=make_awaitable(("", 1))
|
||||
)
|
||||
|
|
|
@ -26,7 +26,7 @@ from twisted.web.http import HTTPChannel
|
|||
|
||||
from synapse.api.errors import RequestSendFailed
|
||||
from synapse.http.matrixfederationclient import (
|
||||
JsonParser,
|
||||
ByteParser,
|
||||
MatrixFederationHttpClient,
|
||||
MatrixFederationRequest,
|
||||
)
|
||||
|
@ -618,9 +618,9 @@ class FederationClientTests(HomeserverTestCase):
|
|||
while not test_d.called:
|
||||
protocol.dataReceived(b"a" * chunk_size)
|
||||
sent += chunk_size
|
||||
self.assertLessEqual(sent, JsonParser.MAX_RESPONSE_SIZE)
|
||||
self.assertLessEqual(sent, ByteParser.MAX_RESPONSE_SIZE)
|
||||
|
||||
self.assertEqual(sent, JsonParser.MAX_RESPONSE_SIZE)
|
||||
self.assertEqual(sent, ByteParser.MAX_RESPONSE_SIZE)
|
||||
|
||||
f = self.failureResultOf(test_d)
|
||||
self.assertIsInstance(f.value, RequestSendFailed)
|
||||
|
|
Loading…
Reference in a new issue