Lay some foundation work to allow workers to only subscribe to some kinds of messages, reducing replication traffic. (#12672)

This commit is contained in:
reivilibre 2022-05-19 16:29:08 +01:00 committed by GitHub
parent eb4aaa1b4b
commit 177b884ad7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 173 additions and 24 deletions

1
changelog.d/12672.misc Normal file
View file

@ -0,0 +1 @@
Lay some foundation work to allow workers to only subscribe to some kinds of messages, reducing replication traffic.

View file

@ -1,5 +1,5 @@
# Copyright 2017 Vector Creations Ltd # Copyright 2017 Vector Creations Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright 2020, 2022 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -101,6 +101,9 @@ class ReplicationCommandHandler:
self._instance_id = hs.get_instance_id() self._instance_id = hs.get_instance_id()
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
# Additional Redis channel suffixes to subscribe to.
self._channels_to_subscribe_to: List[str] = []
self._is_presence_writer = ( self._is_presence_writer = (
hs.get_instance_name() in hs.config.worker.writers.presence hs.get_instance_name() in hs.config.worker.writers.presence
) )
@ -243,6 +246,31 @@ class ReplicationCommandHandler:
# If we're NOT using Redis, this must be handled by the master # If we're NOT using Redis, this must be handled by the master
self._should_insert_client_ips = hs.get_instance_name() == "master" self._should_insert_client_ips = hs.get_instance_name() == "master"
if self._is_master or self._should_insert_client_ips:
self.subscribe_to_channel("USER_IP")
def subscribe_to_channel(self, channel_name: str) -> None:
"""
Indicates that we wish to subscribe to a Redis channel by name.
(The name will later be prefixed with the server name; i.e. subscribing
to the 'ABC' channel actually subscribes to 'example.com/ABC' Redis-side.)
Raises:
- If replication has already started, then it's too late to subscribe
to new channels.
"""
if self._factory is not None:
# We don't allow subscribing after the fact to avoid the chance
# of missing an important message because we didn't subscribe in time.
raise RuntimeError(
"Cannot subscribe to more channels after replication started."
)
if channel_name not in self._channels_to_subscribe_to:
self._channels_to_subscribe_to.append(channel_name)
def _add_command_to_stream_queue( def _add_command_to_stream_queue(
self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand] self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None: ) -> None:
@ -321,7 +349,9 @@ class ReplicationCommandHandler:
# Now create the factory/connection for the subscription stream. # Now create the factory/connection for the subscription stream.
self._factory = RedisDirectTcpReplicationClientFactory( self._factory = RedisDirectTcpReplicationClientFactory(
hs, outbound_redis_connection hs,
outbound_redis_connection,
channel_names=self._channels_to_subscribe_to,
) )
hs.get_reactor().connectTCP( hs.get_reactor().connectTCP(
hs.config.redis.redis_host, hs.config.redis.redis_host,

View file

@ -14,7 +14,7 @@
import logging import logging
from inspect import isawaitable from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast from typing import TYPE_CHECKING, Any, Generic, List, Optional, Type, TypeVar, cast
import attr import attr
import txredisapi import txredisapi
@ -85,14 +85,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
Attributes: Attributes:
synapse_handler: The command handler to handle incoming commands. synapse_handler: The command handler to handle incoming commands.
synapse_stream_name: The *redis* stream name to subscribe to and publish synapse_stream_prefix: The *redis* stream name to subscribe to and publish
from (not anything to do with Synapse replication streams). from (not anything to do with Synapse replication streams).
synapse_outbound_redis_connection: The connection to redis to use to send synapse_outbound_redis_connection: The connection to redis to use to send
commands. commands.
""" """
synapse_handler: "ReplicationCommandHandler" synapse_handler: "ReplicationCommandHandler"
synapse_stream_name: str synapse_stream_prefix: str
synapse_channel_names: List[str]
synapse_outbound_redis_connection: txredisapi.ConnectionHandler synapse_outbound_redis_connection: txredisapi.ConnectionHandler
def __init__(self, *args: Any, **kwargs: Any): def __init__(self, *args: Any, **kwargs: Any):
@ -117,8 +118,13 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
# it's important to make sure that we only send the REPLICATE command once we # it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the # have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end. # POSITION response sent back by the other end.
logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name) fully_qualified_stream_names = [
await make_deferred_yieldable(self.subscribe(self.synapse_stream_name)) f"{self.synapse_stream_prefix}/{stream_suffix}"
for stream_suffix in self.synapse_channel_names
] + [self.synapse_stream_prefix]
logger.info("Sending redis SUBSCRIBE for %r", fully_qualified_stream_names)
await make_deferred_yieldable(self.subscribe(fully_qualified_stream_names))
logger.info( logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command" "Successfully subscribed to redis stream, sending REPLICATE command"
) )
@ -217,7 +223,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
await make_deferred_yieldable( await make_deferred_yieldable(
self.synapse_outbound_redis_connection.publish( self.synapse_outbound_redis_connection.publish(
self.synapse_stream_name, encoded_string self.synapse_stream_prefix, encoded_string
) )
) )
@ -300,20 +306,27 @@ def format_address(address: IAddress) -> str:
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately """This is a reconnecting factory that connects to redis and immediately
subscribes to a stream. subscribes to some streams.
Args: Args:
hs hs
outbound_redis_connection: A connection to redis that will be used to outbound_redis_connection: A connection to redis that will be used to
send outbound commands (this is separate to the redis connection send outbound commands (this is separate to the redis connection
used to subscribe). used to subscribe).
channel_names: A list of channel names to append to the base channel name
to additionally subscribe to.
e.g. if ['ABC', 'DEF'] is specified then we'll listen to:
example.com; example.com/ABC; and example.com/DEF.
""" """
maxDelay = 5 maxDelay = 5
protocol = RedisSubscriber protocol = RedisSubscriber
def __init__( def __init__(
self, hs: "HomeServer", outbound_redis_connection: txredisapi.ConnectionHandler self,
hs: "HomeServer",
outbound_redis_connection: txredisapi.ConnectionHandler,
channel_names: List[str],
): ):
super().__init__( super().__init__(
@ -326,7 +339,8 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
) )
self.synapse_handler = hs.get_replication_command_handler() self.synapse_handler = hs.get_replication_command_handler()
self.synapse_stream_name = hs.hostname self.synapse_stream_prefix = hs.hostname
self.synapse_channel_names = channel_names
self.synapse_outbound_redis_connection = outbound_redis_connection self.synapse_outbound_redis_connection = outbound_redis_connection
@ -340,7 +354,8 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
# protocol. # protocol.
p.synapse_handler = self.synapse_handler p.synapse_handler = self.synapse_handler
p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
p.synapse_stream_name = self.synapse_stream_name p.synapse_stream_prefix = self.synapse_stream_prefix
p.synapse_channel_names = self.synapse_channel_names
return p return p

View file

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict, List, Optional, Tuple from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple
from twisted.internet.address import IPv4Address from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol from twisted.internet.protocol import Protocol
@ -32,6 +33,7 @@ from synapse.server import HomeServer
from tests import unittest from tests import unittest
from tests.server import FakeTransport from tests.server import FakeTransport
from tests.utils import USE_POSTGRES_FOR_TESTS
try: try:
import hiredis import hiredis
@ -475,22 +477,25 @@ class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub.""" """A fake Redis server for pub/sub."""
def __init__(self): def __init__(self):
self._subscribers = set() self._subscribers_by_channel: Dict[
bytes, Set["FakeRedisPubSubProtocol"]
] = defaultdict(set)
def add_subscriber(self, conn): def add_subscriber(self, conn, channel: bytes):
"""A connection has called SUBSCRIBE""" """A connection has called SUBSCRIBE"""
self._subscribers.add(conn) self._subscribers_by_channel[channel].add(conn)
def remove_subscriber(self, conn): def remove_subscriber(self, conn):
"""A connection has called UNSUBSCRIBE""" """A connection has lost connection"""
self._subscribers.discard(conn) for subscribers in self._subscribers_by_channel.values():
subscribers.discard(conn)
def publish(self, conn, channel, msg) -> int: def publish(self, conn, channel: bytes, msg) -> int:
"""A connection want to publish a message to subscribers.""" """A connection want to publish a message to subscribers."""
for sub in self._subscribers: for sub in self._subscribers_by_channel[channel]:
sub.send(["message", channel, msg]) sub.send(["message", channel, msg])
return len(self._subscribers) return len(self._subscribers_by_channel)
def buildProtocol(self, addr): def buildProtocol(self, addr):
return FakeRedisPubSubProtocol(self) return FakeRedisPubSubProtocol(self)
@ -531,9 +536,10 @@ class FakeRedisPubSubProtocol(Protocol):
num_subscribers = self._server.publish(self, channel, message) num_subscribers = self._server.publish(self, channel, message)
self.send(num_subscribers) self.send(num_subscribers)
elif command == b"SUBSCRIBE": elif command == b"SUBSCRIBE":
(channel,) = args for idx, channel in enumerate(args):
self._server.add_subscriber(self) num_channels = idx + 1
self.send(["subscribe", channel, 1]) self._server.add_subscriber(self, channel)
self.send(["subscribe", channel, num_channels])
# Since we use SET/GET to cache things we can safely no-op them. # Since we use SET/GET to cache things we can safely no-op them.
elif command == b"SET": elif command == b"SET":
@ -576,3 +582,27 @@ class FakeRedisPubSubProtocol(Protocol):
def connectionLost(self, reason): def connectionLost(self, reason):
self._server.remove_subscriber(self) self._server.remove_subscriber(self)
class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
"""
A test case that enables Redis, providing a fake Redis server.
"""
if not hiredis:
skip = "Requires hiredis"
if not USE_POSTGRES_FOR_TESTS:
# Redis replication only takes place on Postgres
skip = "Requires Postgres"
def default_config(self) -> Dict[str, Any]:
"""
Overrides the default config to enable Redis.
Even if the test only uses make_worker_hs, the main process needs Redis
enabled otherwise it won't create a Fake Redis server to listen on the
Redis port and accept fake TCP connections.
"""
base = super().default_config()
base["redis"] = {"enabled": True}
return base

View file

@ -0,0 +1,73 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests.replication._base import RedisMultiWorkerStreamTestCase
class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
def test_subscribed_to_enough_redis_channels(self) -> None:
# The default main process is subscribed to the USER_IP channel.
self.assertCountEqual(
self.hs.get_replication_command_handler()._channels_to_subscribe_to,
["USER_IP"],
)
def test_background_worker_subscribed_to_user_ip(self) -> None:
# The default main process is subscribed to the USER_IP channel.
worker1 = self.make_worker_hs(
"synapse.app.generic_worker",
extra_config={
"worker_name": "worker1",
"run_background_tasks_on": "worker1",
"redis": {"enabled": True},
},
)
self.assertIn(
"USER_IP",
worker1.get_replication_command_handler()._channels_to_subscribe_to,
)
# Advance so the Redis subscription gets processed
self.pump(0.1)
# The counts are 2 because both the main process and the worker are subscribed.
self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
self.assertEqual(
len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 2
)
def test_non_background_worker_not_subscribed_to_user_ip(self) -> None:
# The default main process is subscribed to the USER_IP channel.
worker2 = self.make_worker_hs(
"synapse.app.generic_worker",
extra_config={
"worker_name": "worker2",
"run_background_tasks_on": "worker1",
"redis": {"enabled": True},
},
)
self.assertNotIn(
"USER_IP",
worker2.get_replication_command_handler()._channels_to_subscribe_to,
)
# Advance so the Redis subscription gets processed
self.pump(0.1)
# The count is 2 because both the main process and the worker are subscribed.
self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
# For USER_IP, the count is 1 because only the main process is subscribed.
self.assertEqual(
len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1
)