Fix additional type hints from Twisted 21.2.0. (#9591)

This commit is contained in:
Patrick Cloke 2021-03-12 11:37:57 -05:00 committed by GitHub
parent 1e67bff833
commit 55da8df078
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 187 additions and 119 deletions

1
changelog.d/9591.misc Normal file
View file

@ -0,0 +1 @@
Fix incorrect type hints.

View file

@ -164,7 +164,7 @@ class Auth:
async def get_user_by_req( async def get_user_by_req(
self, self,
request: Request, request: SynapseRequest,
allow_guest: bool = False, allow_guest: bool = False,
rights: str = "access", rights: str = "access",
allow_expired: bool = False, allow_expired: bool = False,

View file

@ -880,7 +880,9 @@ class FederationHandlerRegistry:
self.edu_handlers = ( self.edu_handlers = (
{} {}
) # type: Dict[str, Callable[[str, dict], Awaitable[None]]] ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]] self.query_handlers = (
{}
) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]
# Map from type to instance names that we should route EDU handling to. # Map from type to instance names that we should route EDU handling to.
# We randomly choose one instance from the list to route to for each new # We randomly choose one instance from the list to route to for each new
@ -914,7 +916,7 @@ class FederationHandlerRegistry:
self.edu_handlers[edu_type] = handler self.edu_handlers[edu_type] = handler
def register_query_handler( def register_query_handler(
self, query_type: str, handler: Callable[[dict], defer.Deferred] self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
): ):
"""Sets the handler callable that will be used to handle an incoming """Sets the handler callable that will be used to handle an incoming
federation query of the given type. federation query of the given type.
@ -987,7 +989,7 @@ class FederationHandlerRegistry:
# Oh well, let's just log and move on. # Oh well, let's just log and move on.
logger.warning("No handler registered for EDU type %s", edu_type) logger.warning("No handler registered for EDU type %s", edu_type)
async def on_query(self, query_type: str, args: dict): async def on_query(self, query_type: str, args: dict) -> JsonDict:
handler = self.query_handlers.get(query_type) handler = self.query_handlers.get(query_type)
if handler: if handler:
return await handler(args) return await handler(args)

View file

@ -34,6 +34,7 @@ from pymacaroons.exceptions import (
from typing_extensions import TypedDict from typing_extensions import TypedDict
from twisted.web.client import readBody from twisted.web.client import readBody
from twisted.web.http_headers import Headers
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.config.oidc_config import ( from synapse.config.oidc_config import (
@ -538,7 +539,7 @@ class OidcProvider:
""" """
metadata = await self.load_metadata() metadata = await self.load_metadata()
token_endpoint = metadata.get("token_endpoint") token_endpoint = metadata.get("token_endpoint")
headers = { raw_headers = {
"Content-Type": "application/x-www-form-urlencoded", "Content-Type": "application/x-www-form-urlencoded",
"User-Agent": self._http_client.user_agent, "User-Agent": self._http_client.user_agent,
"Accept": "application/json", "Accept": "application/json",
@ -552,10 +553,10 @@ class OidcProvider:
body = urlencode(args, True) body = urlencode(args, True)
# Fill the body/headers with credentials # Fill the body/headers with credentials
uri, headers, body = self._client_auth.prepare( uri, raw_headers, body = self._client_auth.prepare(
method="POST", uri=token_endpoint, headers=headers, body=body method="POST", uri=token_endpoint, headers=raw_headers, body=body
) )
headers = {k: [v] for (k, v) in headers.items()} headers = Headers({k: [v] for (k, v) in raw_headers.items()})
# Do the actual request # Do the actual request
# We're not using the SimpleHttpClient util methods as we don't want to # We're not using the SimpleHttpClient util methods as we don't want to

View file

@ -57,7 +57,13 @@ from twisted.web.client import (
) )
from twisted.web.http import PotentialDataLoss from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse from twisted.web.iweb import (
UNKNOWN_LENGTH,
IAgent,
IBodyProducer,
IPolicyForHTTPS,
IResponse,
)
from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@ -870,6 +876,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
return query_str.encode("utf8") return query_str.encode("utf8")
@implementer(IPolicyForHTTPS)
class InsecureInterceptableContextFactory(ssl.ContextFactory): class InsecureInterceptableContextFactory(ssl.ContextFactory):
""" """
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain. Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.

View file

@ -32,8 +32,9 @@ from twisted.internet.endpoints import (
TCP4ClientEndpoint, TCP4ClientEndpoint,
TCP6ClientEndpoint, TCP6ClientEndpoint,
) )
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
from twisted.internet.protocol import Factory, Protocol from twisted.internet.protocol import Factory, Protocol
from twisted.internet.tcp import Connection
from twisted.python.failure import Failure from twisted.python.failure import Failure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -52,7 +53,9 @@ class LogProducer:
format: A callable to format the log record to a string. format: A callable to format the log record to a string.
""" """
transport = attr.ib(type=ITransport) # This is essentially ITCPTransport, but that is missing certain fields
# (connected and registerProducer) which are part of the implementation.
transport = attr.ib(type=Connection)
_format = attr.ib(type=Callable[[logging.LogRecord], str]) _format = attr.ib(type=Callable[[logging.LogRecord], str])
_buffer = attr.ib(type=deque) _buffer = attr.ib(type=deque)
_paused = attr.ib(default=False, type=bool, init=False) _paused = attr.ib(default=False, type=bool, init=False)
@ -149,8 +152,6 @@ class RemoteHandler(logging.Handler):
if self._connection_waiter: if self._connection_waiter:
return return
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
def fail(failure: Failure) -> None: def fail(failure: Failure) -> None:
# If the Deferred was cancelled (e.g. during shutdown) do not try to # If the Deferred was cancelled (e.g. during shutdown) do not try to
# reconnect (this will cause an infinite loop of errors). # reconnect (this will cause an infinite loop of errors).
@ -163,9 +164,13 @@ class RemoteHandler(logging.Handler):
self._connect() self._connect()
def writer(result: Protocol) -> None: def writer(result: Protocol) -> None:
# Force recognising transport as a Connection and not the more
# generic ITransport.
transport = result.transport # type: Connection # type: ignore
# We have a connection. If we already have a producer, and its # We have a connection. If we already have a producer, and its
# transport is the same, just trigger a resumeProducing. # transport is the same, just trigger a resumeProducing.
if self._producer and result.transport is self._producer.transport: if self._producer and transport is self._producer.transport:
self._producer.resumeProducing() self._producer.resumeProducing()
self._connection_waiter = None self._connection_waiter = None
return return
@ -177,14 +182,16 @@ class RemoteHandler(logging.Handler):
# Make a new producer and start it. # Make a new producer and start it.
self._producer = LogProducer( self._producer = LogProducer(
buffer=self._buffer, buffer=self._buffer,
transport=result.transport, transport=transport,
format=self.format, format=self.format,
) )
result.transport.registerProducer(self._producer, True) transport.registerProducer(self._producer, True)
self._producer.resumeProducing() self._producer.resumeProducing()
self._connection_waiter = None self._connection_waiter = None
self._connection_waiter.addCallbacks(writer, fail) deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred
deferred.addCallbacks(writer, fail)
self._connection_waiter = deferred
def _handle_pressure(self) -> None: def _handle_pressure(self) -> None:
""" """

View file

@ -16,8 +16,8 @@
import logging import logging
from typing import TYPE_CHECKING, Dict, List, Optional from typing import TYPE_CHECKING, Dict, List, Optional
from twisted.internet.base import DelayedCall
from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from twisted.internet.interfaces import IDelayedCall
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher, PusherConfig, ThrottleParams from synapse.push import Pusher, PusherConfig, ThrottleParams
@ -66,7 +66,7 @@ class EmailPusher(Pusher):
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.email = pusher_config.pushkey self.email = pusher_config.pushkey
self.timed_call = None # type: Optional[DelayedCall] self.timed_call = None # type: Optional[IDelayedCall]
self.throttle_params = {} # type: Dict[str, ThrottleParams] self.throttle_params = {} # type: Dict[str, ThrottleParams]
self._inited = False self._inited = False

View file

@ -48,7 +48,7 @@ from synapse.replication.tcp.commands import (
UserIpCommand, UserIpCommand,
UserSyncCommand, UserSyncCommand,
) )
from synapse.replication.tcp.protocol import AbstractConnection from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams import ( from synapse.replication.tcp.streams import (
STREAMS_MAP, STREAMS_MAP,
AccountDataStream, AccountDataStream,
@ -82,7 +82,7 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache"
# the type of the entries in _command_queues_by_stream # the type of the entries in _command_queues_by_stream
_StreamCommandQueue = Deque[ _StreamCommandQueue = Deque[
Tuple[Union[RdataCommand, PositionCommand], AbstractConnection] Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
] ]
@ -174,7 +174,7 @@ class ReplicationCommandHandler:
# The currently connected connections. (The list of places we need to send # The currently connected connections. (The list of places we need to send
# outgoing replication commands to.) # outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection] self._connections = [] # type: List[IReplicationConnection]
LaterGauge( LaterGauge(
"synapse_replication_tcp_resource_total_connections", "synapse_replication_tcp_resource_total_connections",
@ -197,7 +197,7 @@ class ReplicationCommandHandler:
# For each connection, the incoming stream names that have received a POSITION # For each connection, the incoming stream names that have received a POSITION
# from that connection. # from that connection.
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]]
LaterGauge( LaterGauge(
"synapse_replication_tcp_command_queue", "synapse_replication_tcp_command_queue",
@ -220,7 +220,7 @@ class ReplicationCommandHandler:
self._server_notices_sender = hs.get_server_notices_sender() self._server_notices_sender = hs.get_server_notices_sender()
def _add_command_to_stream_queue( def _add_command_to_stream_queue(
self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand] self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None: ) -> None:
"""Queue the given received command for processing """Queue the given received command for processing
@ -267,7 +267,7 @@ class ReplicationCommandHandler:
async def _process_command( async def _process_command(
self, self,
cmd: Union[PositionCommand, RdataCommand], cmd: Union[PositionCommand, RdataCommand],
conn: AbstractConnection, conn: IReplicationConnection,
stream_name: str, stream_name: str,
) -> None: ) -> None:
if isinstance(cmd, PositionCommand): if isinstance(cmd, PositionCommand):
@ -321,10 +321,10 @@ class ReplicationCommandHandler:
"""Get a list of streams that this instances replicates.""" """Get a list of streams that this instances replicates."""
return self._streams_to_replicate return self._streams_to_replicate
def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn) self.send_positions_to_connection(conn)
def send_positions_to_connection(self, conn: AbstractConnection): def send_positions_to_connection(self, conn: IReplicationConnection):
"""Send current position of all streams this process is source of to """Send current position of all streams this process is source of to
the connection. the connection.
""" """
@ -347,7 +347,7 @@ class ReplicationCommandHandler:
) )
def on_USER_SYNC( def on_USER_SYNC(
self, conn: AbstractConnection, cmd: UserSyncCommand self, conn: IReplicationConnection, cmd: UserSyncCommand
) -> Optional[Awaitable[None]]: ) -> Optional[Awaitable[None]]:
user_sync_counter.inc() user_sync_counter.inc()
@ -359,21 +359,23 @@ class ReplicationCommandHandler:
return None return None
def on_CLEAR_USER_SYNC( def on_CLEAR_USER_SYNC(
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
) -> Optional[Awaitable[None]]: ) -> Optional[Awaitable[None]]:
if self._is_master: if self._is_master:
return self._presence_handler.update_external_syncs_clear(cmd.instance_id) return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
else: else:
return None return None
def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand): def on_FEDERATION_ACK(
self, conn: IReplicationConnection, cmd: FederationAckCommand
):
federation_ack_counter.inc() federation_ack_counter.inc()
if self._federation_sender: if self._federation_sender:
self._federation_sender.federation_ack(cmd.instance_name, cmd.token) self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
def on_USER_IP( def on_USER_IP(
self, conn: AbstractConnection, cmd: UserIpCommand self, conn: IReplicationConnection, cmd: UserIpCommand
) -> Optional[Awaitable[None]]: ) -> Optional[Awaitable[None]]:
user_ip_cache_counter.inc() user_ip_cache_counter.inc()
@ -395,7 +397,7 @@ class ReplicationCommandHandler:
assert self._server_notices_sender is not None assert self._server_notices_sender is not None
await self._server_notices_sender.on_user_ip(cmd.user_id) await self._server_notices_sender.on_user_ip(cmd.user_id)
def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name: if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes # Ignore RDATA that are just our own echoes
return return
@ -412,7 +414,7 @@ class ReplicationCommandHandler:
self._add_command_to_stream_queue(conn, cmd) self._add_command_to_stream_queue(conn, cmd)
async def _process_rdata( async def _process_rdata(
self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
) -> None: ) -> None:
"""Process an RDATA command """Process an RDATA command
@ -486,7 +488,7 @@ class ReplicationCommandHandler:
stream_name, instance_name, token, rows stream_name, instance_name, token, rows
) )
def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name: if cmd.instance_name == self._instance_name:
# Ignore POSITION that are just our own echoes # Ignore POSITION that are just our own echoes
return return
@ -496,7 +498,7 @@ class ReplicationCommandHandler:
self._add_command_to_stream_queue(conn, cmd) self._add_command_to_stream_queue(conn, cmd)
async def _process_position( async def _process_position(
self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
) -> None: ) -> None:
"""Process a POSITION command """Process a POSITION command
@ -553,7 +555,9 @@ class ReplicationCommandHandler:
self._streams_by_connection.setdefault(conn, set()).add(stream_name) self._streams_by_connection.setdefault(conn, set()).add(stream_name)
def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand): def on_REMOTE_SERVER_UP(
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
):
""""Called when get a new REMOTE_SERVER_UP command.""" """"Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data) self._replication_data_handler.on_remote_server_up(cmd.data)
@ -576,7 +580,7 @@ class ReplicationCommandHandler:
# between two instances, but that is not currently supported). # between two instances, but that is not currently supported).
self.send_command(cmd, ignore_conn=conn) self.send_command(cmd, ignore_conn=conn)
def new_connection(self, connection: AbstractConnection): def new_connection(self, connection: IReplicationConnection):
"""Called when we have a new connection.""" """Called when we have a new connection."""
self._connections.append(connection) self._connections.append(connection)
@ -603,7 +607,7 @@ class ReplicationCommandHandler:
UserSyncCommand(self._instance_id, user_id, True, now) UserSyncCommand(self._instance_id, user_id, True, now)
) )
def lost_connection(self, connection: AbstractConnection): def lost_connection(self, connection: IReplicationConnection):
"""Called when a connection is closed/lost.""" """Called when a connection is closed/lost."""
# we no longer need _streams_by_connection for this connection. # we no longer need _streams_by_connection for this connection.
streams = self._streams_by_connection.pop(connection, None) streams = self._streams_by_connection.pop(connection, None)
@ -624,7 +628,7 @@ class ReplicationCommandHandler:
return bool(self._connections) return bool(self._connections)
def send_command( def send_command(
self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
): ):
"""Send a command to all connected connections. """Send a command to all connected connections.

View file

@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping > ERROR server stopping
* connection closed by server * * connection closed by server *
""" """
import abc
import fcntl import fcntl
import logging import logging
import struct import struct
@ -54,6 +53,7 @@ from inspect import isawaitable
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List, Optional
from prometheus_client import Counter from prometheus_client import Counter
from zope.interface import Interface, implementer
from twisted.internet import task from twisted.internet import task
from twisted.protocols.basic import LineOnlyReceiver from twisted.protocols.basic import LineOnlyReceiver
@ -121,6 +121,14 @@ class ConnectionStates:
CLOSED = "closed" CLOSED = "closed"
class IReplicationConnection(Interface):
"""An interface for replication connections."""
def send_command(cmd: Command):
"""Send the command down the connection"""
@implementer(IReplicationConnection)
class BaseReplicationStreamProtocol(LineOnlyReceiver): class BaseReplicationStreamProtocol(LineOnlyReceiver):
"""Base replication protocol shared between client and server. """Base replication protocol shared between client and server.
@ -495,20 +503,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ReplicateCommand()) self.send_command(ReplicateCommand())
class AbstractConnection(abc.ABC):
"""An interface for replication connections."""
@abc.abstractmethod
def send_command(self, cmd: Command):
"""Send the command down the connection"""
pass
# This tells python that `BaseReplicationStreamProtocol` implements the
# interface.
AbstractConnection.register(BaseReplicationStreamProtocol)
# The following simply registers metrics for the replication connections # The following simply registers metrics for the replication connections
pending_commands = LaterGauge( pending_commands = LaterGauge(

View file

@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
import attr import attr
import txredisapi import txredisapi
from zope.interface import implementer
from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import IAddress, IConnector from twisted.internet.interfaces import IAddress, IConnector
@ -36,7 +37,7 @@ from synapse.replication.tcp.commands import (
parse_command_from_line, parse_command_from_line,
) )
from synapse.replication.tcp.protocol import ( from synapse.replication.tcp.protocol import (
AbstractConnection, IReplicationConnection,
tcp_inbound_commands_counter, tcp_inbound_commands_counter,
tcp_outbound_commands_counter, tcp_outbound_commands_counter,
) )
@ -66,7 +67,8 @@ class ConstantProperty(Generic[T, V]):
pass pass
class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): @implementer(IReplicationConnection)
class RedisSubscriber(txredisapi.SubscriberProtocol):
"""Connection to redis subscribed to replication stream. """Connection to redis subscribed to replication stream.
This class fulfils two functions: This class fulfils two functions:
@ -75,7 +77,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
connection, parsing *incoming* messages into replication commands, and passing them connection, parsing *incoming* messages into replication commands, and passing them
to `ReplicationCommandHandler` to `ReplicationCommandHandler`
(b) it implements the AbstractConnection API, where it sends *outgoing* commands (b) it implements the IReplicationConnection API, where it sends *outgoing* commands
onto outbound_redis_connection. onto outbound_redis_connection.
Due to the vagaries of `txredisapi` we don't want to have a custom Due to the vagaries of `txredisapi` we don't want to have a custom

View file

@ -15,10 +15,9 @@
import re import re
import twisted.web.server from synapse.api.auth import Auth
import synapse.api.auth
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.site import SynapseRequest
from synapse.types import UserID from synapse.types import UserID
@ -37,13 +36,11 @@ def admin_patterns(path_regex: str, version: str = "v1"):
return patterns return patterns
async def assert_requester_is_admin( async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None:
auth: synapse.api.auth.Auth, request: twisted.web.server.Request
) -> None:
"""Verify that the requester is an admin user """Verify that the requester is an admin user
Args: Args:
auth: api.auth.Auth singleton auth: Auth singleton
request: incoming request request: incoming request
Raises: Raises:
@ -53,11 +50,11 @@ async def assert_requester_is_admin(
await assert_user_is_admin(auth, requester.user) await assert_user_is_admin(auth, requester.user)
async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None: async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
"""Verify that the given user is an admin user """Verify that the given user is an admin user
Args: Args:
auth: api.auth.Auth singleton auth: Auth singleton
user_id: user to check user_id: user to check
Raises: Raises:

View file

@ -17,10 +17,9 @@
import logging import logging
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import ( from synapse.rest.admin._base import (
admin_patterns, admin_patterns,
assert_requester_is_admin, assert_requester_is_admin,
@ -50,7 +49,9 @@ class QuarantineMediaInRoom(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
@ -75,7 +76,9 @@ class QuarantineMediaByUser(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
@ -103,7 +106,7 @@ class QuarantineMediaByID(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_POST( async def on_POST(
self, request: Request, server_name: str, media_id: str self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
@ -127,7 +130,9 @@ class ProtectMediaByID(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]: async def on_POST(
self, request: SynapseRequest, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
@ -148,7 +153,9 @@ class ListMediaInRoom(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
is_admin = await self.auth.is_server_admin(requester.user) is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin: if not is_admin:
@ -166,7 +173,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
self.media_repository = hs.get_media_repository() self.media_repository = hs.get_media_repository()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True) before_ts = parse_integer(request, "before_ts", required=True)
@ -189,7 +196,7 @@ class DeleteMediaByID(RestServlet):
self.media_repository = hs.get_media_repository() self.media_repository = hs.get_media_repository()
async def on_DELETE( async def on_DELETE(
self, request: Request, server_name: str, media_id: str self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
@ -218,7 +225,9 @@ class DeleteMediaByDateSize(RestServlet):
self.server_name = hs.hostname self.server_name = hs.hostname
self.media_repository = hs.get_media_repository() self.media_repository = hs.get_media_repository()
async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]: async def on_POST(
self, request: SynapseRequest, server_name: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True) before_ts = parse_integer(request, "before_ts", required=True)

View file

@ -32,6 +32,7 @@ from synapse.http.servlet import (
assert_params_in_dict, assert_params_in_dict,
parse_json_object_from_request, parse_json_object_from_request,
) )
from synapse.http.site import SynapseRequest
from synapse.types import GroupID, JsonDict from synapse.types import GroupID, JsonDict
from ._base import client_patterns from ._base import client_patterns
@ -70,7 +71,9 @@ class GroupServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -81,7 +84,9 @@ class GroupServlet(RestServlet):
return 200, group_description return 200, group_description
@_validate_group_id @_validate_group_id
async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: async def on_POST(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -111,7 +116,9 @@ class GroupSummaryServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -144,7 +151,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_PUT( async def on_PUT(
self, request: Request, group_id: str, category_id: Optional[str], room_id: str self,
request: SynapseRequest,
group_id: str,
category_id: Optional[str],
room_id: str,
): ):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -176,7 +187,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_DELETE( async def on_DELETE(
self, request: Request, group_id: str, category_id: str, room_id: str self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
): ):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -206,7 +217,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_GET( async def on_GET(
self, request: Request, group_id: str, category_id: str self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -219,7 +230,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_PUT( async def on_PUT(
self, request: Request, group_id: str, category_id: str self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -247,7 +258,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_DELETE( async def on_DELETE(
self, request: Request, group_id: str, category_id: str self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -274,7 +285,9 @@ class GroupCategoriesServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -298,7 +311,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_GET( async def on_GET(
self, request: Request, group_id: str, role_id: str self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -311,7 +324,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_PUT( async def on_PUT(
self, request: Request, group_id: str, role_id: str self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -339,7 +352,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_DELETE( async def on_DELETE(
self, request: Request, group_id: str, role_id: str self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -366,7 +379,9 @@ class GroupRolesServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -399,7 +414,11 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_PUT( async def on_PUT(
self, request: Request, group_id: str, role_id: Optional[str], user_id: str self,
request: SynapseRequest,
group_id: str,
role_id: Optional[str],
user_id: str,
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -431,7 +450,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_DELETE( async def on_DELETE(
self, request: Request, group_id: str, role_id: str, user_id: str self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
): ):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -458,7 +477,9 @@ class GroupRoomServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -481,7 +502,9 @@ class GroupUsersServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -504,7 +527,9 @@ class GroupInvitedUsersServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -526,7 +551,9 @@ class GroupSettingJoinPolicyServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: async def on_PUT(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -554,7 +581,7 @@ class GroupCreateServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -598,7 +625,7 @@ class GroupAdminRoomsServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_PUT( async def on_PUT(
self, request: Request, group_id: str, room_id: str self, request: SynapseRequest, group_id: str, room_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -615,7 +642,7 @@ class GroupAdminRoomsServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_DELETE( async def on_DELETE(
self, request: Request, group_id: str, room_id: str self, request: SynapseRequest, group_id: str, room_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -646,7 +673,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_PUT( async def on_PUT(
self, request: Request, group_id: str, room_id: str, config_key: str self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
): ):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -678,7 +705,9 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
@_validate_group_id @_validate_group_id
async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]: async def on_PUT(
self, request: SynapseRequest, group_id, user_id
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -708,7 +737,9 @@ class GroupAdminUsersKickServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]: async def on_PUT(
self, request: SynapseRequest, group_id, user_id
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -735,7 +766,9 @@ class GroupSelfLeaveServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: async def on_PUT(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -762,7 +795,9 @@ class GroupSelfJoinServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: async def on_PUT(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -789,7 +824,9 @@ class GroupSelfAcceptInviteServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: async def on_PUT(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -816,7 +853,9 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@_validate_group_id @_validate_group_id
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: async def on_PUT(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -839,7 +878,9 @@ class PublicisedGroupsForUserServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.get_user_by_req(request, allow_guest=True)
result = await self.groups_handler.get_publicised_groups_for_user(user_id) result = await self.groups_handler.get_publicised_groups_for_user(user_id)
@ -859,7 +900,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -881,7 +922,7 @@ class GroupsForUserServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()

View file

@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
from twisted.web.server import Request from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.site import SynapseRequest
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer from synapse.app.homeserver import HomeServer
@ -35,7 +36,7 @@ class MediaConfigResource(DirectServeJsonResource):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size} self.limits_dict = {"m.upload.size": config.max_upload_size}
async def _async_render_GET(self, request: Request) -> None: async def _async_render_GET(self, request: SynapseRequest) -> None:
await self.auth.get_user_by_req(request) await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True) respond_with_json(request, 200, self.limits_dict, send_cors=True)

View file

@ -39,6 +39,7 @@ from synapse.http.server import (
respond_with_json_bytes, respond_with_json_bytes,
) )
from synapse.http.servlet import parse_integer, parse_string from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers from synapse.rest.media.v1._base import get_filename_from_headers
@ -185,7 +186,7 @@ class PreviewUrlResource(DirectServeJsonResource):
request.setHeader(b"Allow", b"OPTIONS, GET") request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_GET(self, request: Request) -> None: async def _async_render_GET(self, request: SynapseRequest) -> None:
# XXX: if get_user_by_req fails, what should we do in an async render? # XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)

View file

@ -22,6 +22,7 @@ from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.media.v1.media_storage import SpamMediaException from synapse.rest.media.v1.media_storage import SpamMediaException
if TYPE_CHECKING: if TYPE_CHECKING:
@ -49,7 +50,7 @@ class UploadResource(DirectServeJsonResource):
async def _async_render_OPTIONS(self, request: Request) -> None: async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_POST(self, request: Request) -> None: async def _async_render_POST(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have # TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point # already been uploaded to a tmp file at this point

View file

@ -351,11 +351,9 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self @cache_in_self
def get_http_client_context_factory(self) -> IPolicyForHTTPS: def get_http_client_context_factory(self) -> IPolicyForHTTPS:
return ( if self.config.use_insecure_ssl_client_just_for_testing_do_not_use:
InsecureInterceptableContextFactory() return InsecureInterceptableContextFactory()
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use return RegularPolicyForHTTPS()
else RegularPolicyForHTTPS()
)
@cache_in_self @cache_in_self
def get_simple_http_client(self) -> SimpleHttpClient: def get_simple_http_client(self) -> SimpleHttpClient:

View file

@ -17,7 +17,7 @@ import mock
from synapse.app.generic_worker import GenericWorkerServer from synapse.app.generic_worker import GenericWorkerServer
from synapse.replication.tcp.commands import FederationAckCommand from synapse.replication.tcp.commands import FederationAckCommand
from synapse.replication.tcp.protocol import AbstractConnection from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams.federation import FederationStream from synapse.replication.tcp.streams.federation import FederationStream
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -51,8 +51,10 @@ class FederationAckTestCase(HomeserverTestCase):
""" """
rch = self.hs.get_tcp_replication() rch = self.hs.get_tcp_replication()
# wire up the ReplicationCommandHandler to a mock connection # wire up the ReplicationCommandHandler to a mock connection, which needs
mock_connection = mock.Mock(spec=AbstractConnection) # to implement IReplicationConnection. (Note that Mock doesn't understand
# interfaces, but casing an interface to a list gives the attributes.)
mock_connection = mock.Mock(spec=list(IReplicationConnection))
rch.new_connection(mock_connection) rch.new_connection(mock_connection)
# tell it it received an RDATA row # tell it it received an RDATA row