Pass str to twisted's IReactorTCP (#10895)

This follows a correction made in twisted/twisted#1664 and should fix our Twisted Trial CI job.

Until that change is in a twisted release, we'll have to ignore the type
of the `host` argument. I've raised #10899 to remind us to review the
issue in a few months' time.
This commit is contained in:
David Robertson 2021-09-30 12:51:47 +01:00 committed by GitHub
parent 3aefc7b66d
commit 29364145b2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 27 additions and 11 deletions

1
changelog.d/10895.misc Normal file
View file

@ -0,0 +1 @@
Fix type hints to be compatible with an upcoming change to Twisted.

View file

@ -105,8 +105,13 @@ async def _sendmail(
# set to enable TLS. # set to enable TLS.
factory = build_sender_factory(hostname=smtphost if enable_tls else None) factory = build_sender_factory(hostname=smtphost if enable_tls else None)
# the IReactorTCP interface claims host has to be a bytes, which seems to be wrong reactor.connectTCP(
reactor.connectTCP(smtphost, smtpport, factory, timeout=30, bindAddress=None) # type: ignore[arg-type] smtphost, # type: ignore[arg-type]
smtpport,
factory,
timeout=30,
bindAddress=None,
)
await make_deferred_yieldable(d) await make_deferred_yieldable(d)

View file

@ -315,7 +315,7 @@ class ReplicationCommandHandler:
hs, outbound_redis_connection hs, outbound_redis_connection
) )
hs.get_reactor().connectTCP( hs.get_reactor().connectTCP(
hs.config.redis.redis_host.encode(), hs.config.redis.redis_host, # type: ignore[arg-type]
hs.config.redis.redis_port, hs.config.redis.redis_port,
self._factory, self._factory,
) )
@ -324,7 +324,11 @@ class ReplicationCommandHandler:
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
host = hs.config.worker.worker_replication_host host = hs.config.worker.worker_replication_host
port = hs.config.worker.worker_replication_port port = hs.config.worker.worker_replication_port
hs.get_reactor().connectTCP(host.encode(), port, self._factory) hs.get_reactor().connectTCP(
host, # type: ignore[arg-type]
port,
self._factory,
)
def get_streams(self) -> Dict[str, Stream]: def get_streams(self) -> Dict[str, Stream]:
"""Get a map from stream name to all streams.""" """Get a map from stream name to all streams."""

View file

@ -364,6 +364,12 @@ def lazyConnection(
factory.continueTrying = reconnect factory.continueTrying = reconnect
reactor = hs.get_reactor() reactor = hs.get_reactor()
reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None) reactor.connectTCP(
host, # type: ignore[arg-type]
port,
factory,
timeout=30,
bindAddress=None,
)
return factory.handler return factory.handler

View file

@ -240,7 +240,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if self.hs.config.redis.redis_enabled: if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server. # Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback( self.reactor.add_tcp_client_callback(
b"localhost", "localhost",
6379, 6379,
self.connect_any_redis_attempts, self.connect_any_redis_attempts,
) )
@ -424,7 +424,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
clients = self.reactor.tcpClients clients = self.reactor.tcpClients
while clients: while clients:
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, b"localhost") self.assertEqual(host, "localhost")
self.assertEqual(port, 6379) self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None) client_protocol = client_factory.buildProtocol(None)

View file

@ -317,7 +317,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def __init__(self): def __init__(self):
self.threadpool = ThreadPool(self) self.threadpool = ThreadPool(self)
self._tcp_callbacks = {} self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
self._udp = [] self._udp = []
self.lookups: Dict[str, str] = {} self.lookups: Dict[str, str] = {}
self._thread_callbacks: Deque[Callable[[], None]] = deque() self._thread_callbacks: Deque[Callable[[], None]] = deque()
@ -355,7 +355,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def getThreadPool(self): def getThreadPool(self):
return self.threadpool return self.threadpool
def add_tcp_client_callback(self, host, port, callback): def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
"""Add a callback that will be invoked when we receive a connection """Add a callback that will be invoked when we receive a connection
attempt to the given IP/port using `connectTCP`. attempt to the given IP/port using `connectTCP`.
@ -364,7 +364,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
""" """
self._tcp_callbacks[(host, port)] = callback self._tcp_callbacks[(host, port)] = callback
def connectTCP(self, host, port, factory, timeout=30, bindAddress=None): def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
"""Fake L{IReactorTCP.connectTCP}.""" """Fake L{IReactorTCP.connectTCP}."""
conn = super().connectTCP( conn = super().connectTCP(
@ -475,7 +475,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
return server return server
def get_clock(): def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
clock = ThreadedMemoryReactorClock() clock = ThreadedMemoryReactorClock()
hs_clock = Clock(clock) hs_clock = Clock(clock)
return clock, hs_clock return clock, hs_clock