mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-22 01:25:44 +03:00
Fix remaining mypy issues due to Twisted upgrade. (#9608)
This commit is contained in:
parent
026503fa3b
commit
d29b71aa50
8 changed files with 42 additions and 34 deletions
1
changelog.d/9608.misc
Normal file
1
changelog.d/9608.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix incorrect type hints.
|
|
@ -19,7 +19,7 @@ from typing import Any, List, Optional, Type, Union
|
||||||
|
|
||||||
from twisted.internet import protocol
|
from twisted.internet import protocol
|
||||||
|
|
||||||
class RedisProtocol:
|
class RedisProtocol(protocol.Protocol):
|
||||||
def publish(self, channel: str, message: bytes): ...
|
def publish(self, channel: str, message: bytes): ...
|
||||||
async def ping(self) -> None: ...
|
async def ping(self) -> None: ...
|
||||||
async def set(
|
async def set(
|
||||||
|
|
|
@ -45,7 +45,9 @@ from twisted.internet.interfaces import (
|
||||||
IHostResolution,
|
IHostResolution,
|
||||||
IReactorPluggableNameResolver,
|
IReactorPluggableNameResolver,
|
||||||
IResolutionReceiver,
|
IResolutionReceiver,
|
||||||
|
ITCPTransport,
|
||||||
)
|
)
|
||||||
|
from twisted.internet.protocol import connectionDone
|
||||||
from twisted.internet.task import Cooperator
|
from twisted.internet.task import Cooperator
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.web._newclient import ResponseDone
|
from twisted.web._newclient import ResponseDone
|
||||||
|
@ -760,6 +762,8 @@ class BodyExceededMaxSize(Exception):
|
||||||
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
|
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
|
||||||
"""A protocol which immediately errors upon receiving data."""
|
"""A protocol which immediately errors upon receiving data."""
|
||||||
|
|
||||||
|
transport = None # type: Optional[ITCPTransport]
|
||||||
|
|
||||||
def __init__(self, deferred: defer.Deferred):
|
def __init__(self, deferred: defer.Deferred):
|
||||||
self.deferred = deferred
|
self.deferred = deferred
|
||||||
|
|
||||||
|
@ -771,18 +775,21 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
|
||||||
self.deferred.errback(BodyExceededMaxSize())
|
self.deferred.errback(BodyExceededMaxSize())
|
||||||
# Close the connection (forcefully) since all the data will get
|
# Close the connection (forcefully) since all the data will get
|
||||||
# discarded anyway.
|
# discarded anyway.
|
||||||
|
assert self.transport is not None
|
||||||
self.transport.abortConnection()
|
self.transport.abortConnection()
|
||||||
|
|
||||||
def dataReceived(self, data: bytes) -> None:
|
def dataReceived(self, data: bytes) -> None:
|
||||||
self._maybe_fail()
|
self._maybe_fail()
|
||||||
|
|
||||||
def connectionLost(self, reason: Failure) -> None:
|
def connectionLost(self, reason: Failure = connectionDone) -> None:
|
||||||
self._maybe_fail()
|
self._maybe_fail()
|
||||||
|
|
||||||
|
|
||||||
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
||||||
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
|
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
|
||||||
|
|
||||||
|
transport = None # type: Optional[ITCPTransport]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
|
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
|
||||||
):
|
):
|
||||||
|
@ -805,9 +812,10 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
||||||
self.deferred.errback(BodyExceededMaxSize())
|
self.deferred.errback(BodyExceededMaxSize())
|
||||||
# Close the connection (forcefully) since all the data will get
|
# Close the connection (forcefully) since all the data will get
|
||||||
# discarded anyway.
|
# discarded anyway.
|
||||||
|
assert self.transport is not None
|
||||||
self.transport.abortConnection()
|
self.transport.abortConnection()
|
||||||
|
|
||||||
def connectionLost(self, reason: Failure) -> None:
|
def connectionLost(self, reason: Failure = connectionDone) -> None:
|
||||||
# If the maximum size was already exceeded, there's nothing to do.
|
# If the maximum size was already exceeded, there's nothing to do.
|
||||||
if self.deferred.called:
|
if self.deferred.called:
|
||||||
return
|
return
|
||||||
|
|
|
@ -302,7 +302,7 @@ class ReplicationCommandHandler:
|
||||||
hs, outbound_redis_connection
|
hs, outbound_redis_connection
|
||||||
)
|
)
|
||||||
hs.get_reactor().connectTCP(
|
hs.get_reactor().connectTCP(
|
||||||
hs.config.redis.redis_host,
|
hs.config.redis.redis_host.encode(),
|
||||||
hs.config.redis.redis_port,
|
hs.config.redis.redis_port,
|
||||||
self._factory,
|
self._factory,
|
||||||
)
|
)
|
||||||
|
@ -311,7 +311,7 @@ class ReplicationCommandHandler:
|
||||||
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
|
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
|
||||||
host = hs.config.worker_replication_host
|
host = hs.config.worker_replication_host
|
||||||
port = hs.config.worker_replication_port
|
port = hs.config.worker_replication_port
|
||||||
hs.get_reactor().connectTCP(host, port, self._factory)
|
hs.get_reactor().connectTCP(host.encode(), port, self._factory)
|
||||||
|
|
||||||
def get_streams(self) -> Dict[str, Stream]:
|
def get_streams(self) -> Dict[str, Stream]:
|
||||||
"""Get a map from stream name to all streams."""
|
"""Get a map from stream name to all streams."""
|
||||||
|
|
|
@ -56,6 +56,7 @@ from prometheus_client import Counter
|
||||||
from zope.interface import Interface, implementer
|
from zope.interface import Interface, implementer
|
||||||
|
|
||||||
from twisted.internet import task
|
from twisted.internet import task
|
||||||
|
from twisted.internet.tcp import Connection
|
||||||
from twisted.protocols.basic import LineOnlyReceiver
|
from twisted.protocols.basic import LineOnlyReceiver
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
|
@ -145,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||||
(if they send a `PING` command)
|
(if they send a `PING` command)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# The transport is going to be an ITCPTransport, but that doesn't have the
|
||||||
|
# (un)registerProducer methods, those are only on the implementation.
|
||||||
|
transport = None # type: Connection
|
||||||
|
|
||||||
delimiter = b"\n"
|
delimiter = b"\n"
|
||||||
|
|
||||||
# Valid commands we expect to receive
|
# Valid commands we expect to receive
|
||||||
|
@ -189,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||||
|
|
||||||
connected_connections.append(self) # Register connection for metrics
|
connected_connections.append(self) # Register connection for metrics
|
||||||
|
|
||||||
|
assert self.transport is not None
|
||||||
self.transport.registerProducer(self, True) # For the *Producing callbacks
|
self.transport.registerProducer(self, True) # For the *Producing callbacks
|
||||||
|
|
||||||
self._send_pending_commands()
|
self._send_pending_commands()
|
||||||
|
@ -213,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||||
logger.info(
|
logger.info(
|
||||||
"[%s] Failed to close connection gracefully, aborting", self.id()
|
"[%s] Failed to close connection gracefully, aborting", self.id()
|
||||||
)
|
)
|
||||||
|
assert self.transport is not None
|
||||||
self.transport.abortConnection()
|
self.transport.abortConnection()
|
||||||
else:
|
else:
|
||||||
if now - self.last_sent_command >= PING_TIME:
|
if now - self.last_sent_command >= PING_TIME:
|
||||||
|
@ -302,6 +309,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||||
def close(self):
|
def close(self):
|
||||||
logger.warning("[%s] Closing connection", self.id())
|
logger.warning("[%s] Closing connection", self.id())
|
||||||
self.time_we_closed = self.clock.time_msec()
|
self.time_we_closed = self.clock.time_msec()
|
||||||
|
assert self.transport is not None
|
||||||
self.transport.loseConnection()
|
self.transport.loseConnection()
|
||||||
self.on_connection_closed()
|
self.on_connection_closed()
|
||||||
|
|
||||||
|
@ -399,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||||
def connectionLost(self, reason):
|
def connectionLost(self, reason):
|
||||||
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
|
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
|
||||||
if isinstance(reason, Failure):
|
if isinstance(reason, Failure):
|
||||||
|
assert reason.type is not None
|
||||||
connection_close_counter.labels(reason.type.__name__).inc()
|
connection_close_counter.labels(reason.type.__name__).inc()
|
||||||
else:
|
else:
|
||||||
connection_close_counter.labels(reason.__class__.__name__).inc()
|
connection_close_counter.labels(reason.__class__.__name__).inc()
|
||||||
|
|
|
@ -365,6 +365,6 @@ def lazyConnection(
|
||||||
factory.continueTrying = reconnect
|
factory.continueTrying = reconnect
|
||||||
|
|
||||||
reactor = hs.get_reactor()
|
reactor = hs.get_reactor()
|
||||||
reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None)
|
reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None)
|
||||||
|
|
||||||
return factory.handler
|
return factory.handler
|
||||||
|
|
|
@ -13,9 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import attr
|
|
||||||
|
|
||||||
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
|
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
|
||||||
from twisted.internet.protocol import Protocol
|
from twisted.internet.protocol import Protocol
|
||||||
|
@ -158,10 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
# Set up client side protocol
|
# Set up client side protocol
|
||||||
client_protocol = client_factory.buildProtocol(None)
|
client_protocol = client_factory.buildProtocol(None)
|
||||||
|
|
||||||
request_factory = OneShotRequestFactory()
|
|
||||||
|
|
||||||
# Set up the server side protocol
|
# Set up the server side protocol
|
||||||
channel = _PushHTTPChannel(self.reactor, request_factory, self.site)
|
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
|
||||||
|
|
||||||
# Connect client to server and vice versa.
|
# Connect client to server and vice versa.
|
||||||
client_to_server_transport = FakeTransport(
|
client_to_server_transport = FakeTransport(
|
||||||
|
@ -183,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
server_to_client_transport.loseConnection()
|
server_to_client_transport.loseConnection()
|
||||||
client_to_server_transport.loseConnection()
|
client_to_server_transport.loseConnection()
|
||||||
|
|
||||||
return request_factory.request
|
return channel.request
|
||||||
|
|
||||||
def assert_request_is_get_repl_stream_updates(
|
def assert_request_is_get_repl_stream_updates(
|
||||||
self, request: SynapseRequest, stream_name: str
|
self, request: SynapseRequest, stream_name: str
|
||||||
|
@ -237,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||||
if self.hs.config.redis.redis_enabled:
|
if self.hs.config.redis.redis_enabled:
|
||||||
# Handle attempts to connect to fake redis server.
|
# Handle attempts to connect to fake redis server.
|
||||||
self.reactor.add_tcp_client_callback(
|
self.reactor.add_tcp_client_callback(
|
||||||
"localhost",
|
b"localhost",
|
||||||
6379,
|
6379,
|
||||||
self.connect_any_redis_attempts,
|
self.connect_any_redis_attempts,
|
||||||
)
|
)
|
||||||
|
@ -392,10 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||||
# Set up client side protocol
|
# Set up client side protocol
|
||||||
client_protocol = client_factory.buildProtocol(None)
|
client_protocol = client_factory.buildProtocol(None)
|
||||||
|
|
||||||
request_factory = OneShotRequestFactory()
|
|
||||||
|
|
||||||
# Set up the server side protocol
|
# Set up the server side protocol
|
||||||
channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs])
|
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
|
||||||
|
|
||||||
# Connect client to server and vice versa.
|
# Connect client to server and vice versa.
|
||||||
client_to_server_transport = FakeTransport(
|
client_to_server_transport = FakeTransport(
|
||||||
|
@ -421,7 +415,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||||
clients = self.reactor.tcpClients
|
clients = self.reactor.tcpClients
|
||||||
while clients:
|
while clients:
|
||||||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
|
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
|
||||||
self.assertEqual(host, "localhost")
|
self.assertEqual(host, b"localhost")
|
||||||
self.assertEqual(port, 6379)
|
self.assertEqual(port, 6379)
|
||||||
|
|
||||||
client_protocol = client_factory.buildProtocol(None)
|
client_protocol = client_factory.buildProtocol(None)
|
||||||
|
@ -453,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler):
|
||||||
self.received_rdata_rows.append((stream_name, token, r))
|
self.received_rdata_rows.append((stream_name, token, r))
|
||||||
|
|
||||||
|
|
||||||
@attr.s()
|
|
||||||
class OneShotRequestFactory:
|
|
||||||
"""A simple request factory that generates a single `SynapseRequest` and
|
|
||||||
stores it for future use. Can only be used once.
|
|
||||||
"""
|
|
||||||
|
|
||||||
request = attr.ib(default=None)
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
assert self.request is None
|
|
||||||
|
|
||||||
self.request = SynapseRequest(*args, **kwargs)
|
|
||||||
return self.request
|
|
||||||
|
|
||||||
|
|
||||||
class _PushHTTPChannel(HTTPChannel):
|
class _PushHTTPChannel(HTTPChannel):
|
||||||
"""A HTTPChannel that wraps pull producers to push producers.
|
"""A HTTPChannel that wraps pull producers to push producers.
|
||||||
|
|
||||||
|
@ -479,7 +458,7 @@ class _PushHTTPChannel(HTTPChannel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site
|
self, reactor: IReactorTime, request_factory: Type[Request], site: Site
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.reactor = reactor
|
self.reactor = reactor
|
||||||
|
@ -510,6 +489,11 @@ class _PushHTTPChannel(HTTPChannel):
|
||||||
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
|
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def requestDone(self, request):
|
||||||
|
# Store the request for inspection.
|
||||||
|
self.request = request
|
||||||
|
super().requestDone(request)
|
||||||
|
|
||||||
|
|
||||||
class _PullToPushProducer:
|
class _PullToPushProducer:
|
||||||
"""A push producer that wraps a pull producer."""
|
"""A push producer that wraps a pull producer."""
|
||||||
|
@ -597,6 +581,8 @@ class FakeRedisPubSubServer:
|
||||||
class FakeRedisPubSubProtocol(Protocol):
|
class FakeRedisPubSubProtocol(Protocol):
|
||||||
"""A connection from a client talking to the fake Redis server."""
|
"""A connection from a client talking to the fake Redis server."""
|
||||||
|
|
||||||
|
transport = None # type: Optional[FakeTransport]
|
||||||
|
|
||||||
def __init__(self, server: FakeRedisPubSubServer):
|
def __init__(self, server: FakeRedisPubSubServer):
|
||||||
self._server = server
|
self._server = server
|
||||||
self._reader = hiredis.Reader()
|
self._reader = hiredis.Reader()
|
||||||
|
@ -641,6 +627,8 @@ class FakeRedisPubSubProtocol(Protocol):
|
||||||
|
|
||||||
def send(self, msg):
|
def send(self, msg):
|
||||||
"""Send a message back to the client."""
|
"""Send a message back to the client."""
|
||||||
|
assert self.transport is not None
|
||||||
|
|
||||||
raw = self.encode(msg).encode("utf-8")
|
raw = self.encode(msg).encode("utf-8")
|
||||||
|
|
||||||
self.transport.write(raw)
|
self.transport.write(raw)
|
||||||
|
|
|
@ -16,6 +16,7 @@ from twisted.internet.interfaces import (
|
||||||
IReactorPluggableNameResolver,
|
IReactorPluggableNameResolver,
|
||||||
IReactorTCP,
|
IReactorTCP,
|
||||||
IResolverSimple,
|
IResolverSimple,
|
||||||
|
ITransport,
|
||||||
)
|
)
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
|
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
|
||||||
|
@ -467,6 +468,7 @@ def get_clock():
|
||||||
return clock, hs_clock
|
return clock, hs_clock
|
||||||
|
|
||||||
|
|
||||||
|
@implementer(ITransport)
|
||||||
@attr.s(cmp=False)
|
@attr.s(cmp=False)
|
||||||
class FakeTransport:
|
class FakeTransport:
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in a new issue