From ef4f063687a0228eb049a3253f1f38597e066a0a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Mar 2020 16:22:39 +0000 Subject: [PATCH] Move command processing out of transport --- synapse/app/generic_worker.py | 21 +- synapse/replication/tcp/client.py | 191 +-------- synapse/replication/tcp/handler.py | 401 ++++++++++++++++++ synapse/replication/tcp/protocol.py | 286 ++----------- synapse/replication/tcp/resource.py | 165 +------ synapse/server.py | 15 +- tests/replication/slave/storage/_base.py | 22 +- tests/replication/tcp/streams/_base.py | 38 +- .../replication/tcp/streams/test_receipts.py | 1 - 9 files changed, 507 insertions(+), 633 deletions(-) create mode 100644 synapse/replication/tcp/handler.py diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index f125658615..e0fdef5cdb 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -64,8 +64,9 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore -from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.replication.tcp.client import ReplicationClientFactory from synapse.replication.tcp.commands import ClearUserSyncsCommand +from synapse.replication.tcp.handler import WorkerReplicationDataHandler from synapse.replication.tcp.streams import ( AccountDataStream, DeviceListsStream, @@ -598,25 +599,26 @@ class GenericWorkerServer(HomeServer): else: logger.warning("Unrecognized listener type: %s", listener["type"]) - self.get_tcp_replication().start_replication(self) + factory = ReplicationClientFactory(self, self.config.worker_name) + host = self.config.worker_replication_host + port = self.config.worker_replication_port + self.get_reactor().connectTCP(host, port, factory) def remove_pusher(self, app_id, push_key, user_id): self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) - def build_tcp_replication(self): - return GenericWorkerReplicationHandler(self) - def build_presence_handler(self): return GenericWorkerPresence(self) def build_typing_handler(self): return GenericWorkerTyping(self) + def build_replication_data_handler(self): + return GenericWorkerReplicationHandler(self) -class GenericWorkerReplicationHandler(ReplicationClientHandler): + +class GenericWorkerReplicationHandler(WorkerReplicationDataHandler): def __init__(self, hs): - super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore()) - self.store = hs.get_datastore() self.typing_handler = hs.get_typing_handler() # NB this is a SynchrotronPresence, not a normal PresenceHandler @@ -644,9 +646,6 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler): args.update(self.send_handler.stream_positions()) return args - def get_currently_syncing_users(self): - return self.presence_handler.get_currently_syncing_users() - async def process_and_notify(self, stream_name, token, rows): try: if self.send_handler: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e86d9805f1..e60baf2bd5 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,26 +16,10 @@ """ import logging -from typing import Dict, List, Optional -from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory -from 
synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.tcp.protocol import ( - AbstractReplicationClientHandler, - ClientReplicationStreamProtocol, -) - -from .commands import ( - Command, - FederationAckCommand, - InvalidateCacheCommand, - RemoteServerUpCommand, - RemovePusherCommand, - UserIpCommand, - UserSyncCommand, -) +from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol logger = logging.getLogger(__name__) @@ -51,9 +35,9 @@ class ReplicationClientFactory(ReconnectingClientFactory): initialDelay = 0.1 maxDelay = 1 # Try at least once every N seconds - def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler): + def __init__(self, hs, client_name): self.client_name = client_name - self.handler = handler + self.handler = hs.get_tcp_replication() self.server_name = hs.config.server_name self.hs = hs self._clock = hs.get_clock() # As self.clock is defined in super class @@ -76,172 +60,3 @@ class ReplicationClientFactory(ReconnectingClientFactory): def clientConnectionFailed(self, connector, reason): logger.error("Failed to connect to replication: %r", reason) ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) - - -class ReplicationClientHandler(AbstractReplicationClientHandler): - """A base handler that can be passed to the ReplicationClientFactory. - - By default proxies incoming replication data to the SlaveStore. - """ - - def __init__(self, store: BaseSlavedStore): - self.store = store - - # The current connection. None if we are currently (re)connecting - self.connection = None - - # Any pending commands to be sent once a new connection has been - # established - self.pending_commands = [] # type: List[Command] - - # Map from string -> deferred, to wake up when receiveing a SYNC with - # the given string. - # Used for tests. - self.awaiting_syncs = {} # type: Dict[str, defer.Deferred] - - # The factory used to create connections. - self.factory = None # type: Optional[ReplicationClientFactory] - - def start_replication(self, hs): - """Helper method to start a replication connection to the remote server - using TCP. - """ - client_name = hs.config.worker_name - self.factory = ReplicationClientFactory(hs, client_name, self) - host = hs.config.worker_replication_host - port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, self.factory) - - async def on_rdata(self, stream_name, token, rows): - """Called to handle a batch of replication data with a given stream token. - - By default this just pokes the slave store. Can be overridden in subclasses to - handle more. - - Args: - stream_name (str): name of the replication stream for this batch of rows - token (int): stream token for this batch of rows - rows (list): a list of Stream.ROW_TYPE objects as returned by - Stream.parse_row. - """ - logger.debug("Received rdata %s -> %s", stream_name, token) - self.store.process_replication_rows(stream_name, token, rows) - - async def on_position(self, stream_name, token): - """Called when we get new position data. By default this just pokes - the slave store. - - Can be overriden in subclasses to handle more. - """ - self.store.process_replication_rows(stream_name, token, []) - - def on_sync(self, data): - """When we received a SYNC we wake up any deferreds that were waiting - for the sync with the given data. - - Used by tests. 
- """ - d = self.awaiting_syncs.pop(data, None) - if d: - d.callback(data) - - def on_remote_server_up(self, server: str): - """Called when get a new REMOTE_SERVER_UP command.""" - - def get_streams_to_replicate(self) -> Dict[str, int]: - """Called when a new connection has been established and we need to - subscribe to streams. - - Returns: - map from stream name to the most recent update we have for - that stream (ie, the point we want to start replicating from) - """ - args = self.store.stream_positions() - user_account_data = args.pop("user_account_data", None) - room_account_data = args.pop("room_account_data", None) - if user_account_data: - args["account_data"] = user_account_data - elif room_account_data: - args["account_data"] = room_account_data - - return args - - def get_currently_syncing_users(self): - """Get the list of currently syncing users (if any). This is called - when a connection has been established and we need to send the - currently syncing users. (Overriden by the synchrotron's only) - """ - return [] - - def send_command(self, cmd): - """Send a command to master (when we get establish a connection if we - don't have one already.) - """ - if self.connection: - self.connection.send_command(cmd) - else: - logger.warning("Queuing command as not connected: %r", cmd.NAME) - self.pending_commands.append(cmd) - - def send_federation_ack(self, token): - """Ack data for the federation stream. This allows the master to drop - data stored purely in memory. - """ - self.send_command(FederationAckCommand(token)) - - def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): - """Poke the master that a user has started/stopped syncing. - """ - self.send_command( - UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) - ) - - def send_remove_pusher(self, app_id, push_key, user_id): - """Poke the master to remove a pusher for a user - """ - cmd = RemovePusherCommand(app_id, push_key, user_id) - self.send_command(cmd) - - def send_invalidate_cache(self, cache_func, keys): - """Poke the master to invalidate a cache. - """ - cmd = InvalidateCacheCommand(cache_func.__name__, keys) - self.send_command(cmd) - - def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): - """Tell the master that the user made a request. - """ - cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) - self.send_command(cmd) - - def send_remote_server_up(self, server: str): - self.send_command(RemoteServerUpCommand(server)) - - def await_sync(self, data): - """Returns a deferred that is resolved when we receive a SYNC command - with given data. - - [Not currently] used by tests. - """ - return self.awaiting_syncs.setdefault(data, defer.Deferred()) - - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). - """ - self.connection = connection - if connection: - for cmd in self.pending_commands: - connection.send_command(cmd) - self.pending_commands = [] - - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. - """ - logger.info("Finished connecting to server") - - # We don't reset the delay any earlier as otherwise if there is a - # problem during start up we'll end up tight looping connecting to the - # server. 
- if self.factory: - self.factory.resetDelay() diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py new file mode 100644 index 0000000000..a59ab01471 --- /dev/null +++ b/synapse/replication/tcp/handler.py @@ -0,0 +1,401 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A replication client for use by synapse workers. +""" + +import logging +from typing import Any, Dict, List + +from prometheus_client import Counter + +from synapse.metrics import LaterGauge +from synapse.replication.tcp.commands import ( + ClearUserSyncsCommand, + Command, + FederationAckCommand, + InvalidateCacheCommand, + PositionCommand, + RdataCommand, + RemoteServerUpCommand, + RemovePusherCommand, + ReplicateCommand, + UserIpCommand, + UserSyncCommand, +) +from synapse.replication.tcp.streams import STREAMS_MAP, Stream + +logger = logging.getLogger(__name__) + + +user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") +federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") +remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") +invalidate_cache_counter = Counter( + "synapse_replication_tcp_resource_invalidate_cache", "" +) +user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") + + +class ReplicationClientHandler: + """Handles incoming commands from replication. + + Proxies data to `HomeServer.get_replication_data_handler()`. + """ + + def __init__(self, hs): + self.replication_data_handler = hs.get_replication_data_handler() + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + self.clock = hs.get_clock() + self.presence_handler = hs.get_presence_handler() + self.instance_id = hs.get_instance_id() + + self.connections = [] + + self.streams = { + stream.NAME: stream(hs) for stream in STREAMS_MAP.values() + } # type: Dict[str, Stream] + + LaterGauge( + "synapse_replication_tcp_resource_total_connections", + "", + [], + lambda: len(self.connections), + ) + + LaterGauge( + "synapse_replication_tcp_resource_connections_per_stream", + "", + ["stream_name"], + lambda: { + (stream_name,): len( + [ + conn + for conn in self.connections + if stream_name in conn.replication_streams + ] + ) + for stream_name in self.streams + }, + ) + + # Map of stream to batched updates. See RdataCommand for info on how + # batching works. 
+ self.pending_batches = {} # type: Dict[str, List[Any]] + + self.is_master = hs.config.worker_app is None + + self.federation_sender = None + if self.is_master and not hs.config.send_federation: + self.federation_sender = hs.get_federation_sender() + + self._server_notices_sender = None + if self.is_master: + self._server_notices_sender = hs.get_server_notices_sender() + self.notifier.add_remote_server_up_callback(self.send_remote_server_up) + + def new_connection(self, connection): + self.connections.append(connection) + + def lost_connection(self, connection): + try: + self.connections.remove(connection) + except ValueError: + pass + + def connected(self) -> bool: + """Do we have any replication connections open? + + Used to no-op if nothing is connected. + """ + return bool(self.connections) + + async def on_REPLICATE(self, cmd: ReplicateCommand): + # We only want to announce positions by the writer of the streams. + # Currently this is just the master process. + if not self.is_master: + return + + if not self.connections: + raise Exception("Not connected") + + for stream_name, stream in self.streams.items(): + current_token = stream.current_token() + self.send_command(PositionCommand(stream_name, current_token)) + + async def on_USER_SYNC(self, cmd: UserSyncCommand): + user_sync_counter.inc() + + if self.is_master: + await self.presence_handler.update_external_syncs_row( + cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms + ) + + async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand): + if self.is_master: + await self.presence_handler.update_external_syncs_clear(cmd.instance_id) + + async def on_FEDERATION_ACK(self, cmd: FederationAckCommand): + federation_ack_counter.inc() + + if self.federation_sender: + self.federation_sender.federation_ack(cmd.token) + + async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand): + remove_pusher_counter.inc() + + if self.is_master: + await self.store.delete_pusher_by_app_id_pushkey_user_id( + app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id + ) + + self.notifier.on_new_replication_data() + + async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand): + invalidate_cache_counter.inc() + + if self.is_master: + # We invalidate the cache locally, but then also stream that to other + # workers. + await self.store.invalidate_cache_and_stream( + cmd.cache_func, tuple(cmd.keys) + ) + + async def on_USER_IP(self, cmd: UserIpCommand): + user_ip_cache_counter.inc() + + if self.is_master: + await self.store.insert_client_ip( + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, + cmd.last_seen, + ) + await self._server_notices_sender.on_user_ip(cmd.user_id) + + async def on_RDATA(self, cmd: RdataCommand): + stream_name = cmd.stream_name + + try: + row = STREAMS_MAP[stream_name].parse_row(cmd.row) + except Exception: + logger.exception("[%s] Failed to parse RDATA: %r", stream_name, cmd.row) + raise + + if cmd.token is None: + # I.e. this is part of a batch of updates for this stream. Batch + # until we get an update for the stream with a non None token + self.pending_batches.setdefault(stream_name, []).append(row) + else: + # Check if this is the last of a batch of updates + rows = self.pending_batches.pop(stream_name, []) + rows.append(row) + await self.on_rdata(stream_name, cmd.token, rows) + + async def on_rdata(self, stream_name: str, token: int, rows: list): + """Called to handle a batch of replication data with a given stream token. 
+
+        Args:
+            stream_name: name of the replication stream for this batch of rows
+            token: stream token for this batch of rows
+            rows: a list of Stream.ROW_TYPE objects as returned by
+                Stream.parse_row.
+        """
+        logger.debug("Received rdata %s -> %s", stream_name, token)
+        await self.replication_data_handler.on_rdata(stream_name, token, rows)
+
+    async def on_POSITION(self, cmd: PositionCommand):
+        stream = self.streams.get(cmd.stream_name)
+        if not stream:
+            logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
+            return
+
+        # Find where we previously streamed up to.
+        current_token = self.replication_data_handler.get_streams_to_replicate().get(
+            cmd.stream_name
+        )
+        if current_token is None:
+            logger.debug(
+                "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
+            )
+            return
+
+        # Fetch all updates between then and now.
+        limited = cmd.token != current_token
+        while limited:
+            updates, current_token, limited = await stream.get_updates_since(
+                current_token, cmd.token
+            )
+            if updates:
+                await self.on_rdata(
+                    cmd.stream_name,
+                    current_token,
+                    [stream.parse_row(update[1]) for update in updates],
+                )
+
+        # We've now caught up to the position sent to us; notify the handler.
+        await self.replication_data_handler.on_position(cmd.stream_name, cmd.token)
+
+        # Handle any RDATA that came in while we were catching up.
+        rows = self.pending_batches.pop(cmd.stream_name, [])
+        if rows:
+            await self.on_rdata(cmd.stream_name, rows[-1].token, rows)
+
+    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+        """Called when we get a new REMOTE_SERVER_UP command."""
+        if self.is_master:
+            self.notifier.notify_remote_server_up(cmd.server)
+
+    def get_currently_syncing_users(self):
+        """Get the list of currently syncing users (if any). This is called
+        when a connection has been established and we need to send the
+        currently syncing users.
+        """
+        return self.presence_handler.get_currently_syncing_users()
+
+    def send_command(self, cmd: Command):
+        """Send a command to the master (over any open replication connections).
+        """
+        for conn in self.connections:
+            conn.send_command(cmd)
+
+    def send_federation_ack(self, token: int):
+        """Ack data for the federation stream. This allows the master to drop
+        data stored purely in memory.
+        """
+        self.send_command(FederationAckCommand(token))
+
+    def send_user_sync(
+        self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+    ):
+        """Poke the master that a user has started/stopped syncing.
+        """
+        self.send_command(
+            UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+        )
+
+    def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
+        """Poke the master to remove a pusher for a user.
+        """
+        cmd = RemovePusherCommand(app_id, push_key, user_id)
+        self.send_command(cmd)
+
+    def send_invalidate_cache(self, cache_func, keys: tuple):
+        """Poke the master to invalidate a cache.
+        """
+        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
+        self.send_command(cmd)
+
+    def send_user_ip(
+        self,
+        user_id: str,
+        access_token: str,
+        ip: str,
+        user_agent: str,
+        device_id: str,
+        last_seen: int,
+    ):
+        """Tell the master that the user made a request.
+        """
+        cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
+        self.send_command(cmd)
+
+    def send_remote_server_up(self, server: str):
+        self.send_command(RemoteServerUpCommand(server))
+
+    def stream_update(self, stream_name: str, token: str, data: Any):
+        """Called when a new update is available to stream to clients.
+
+        We need to check if the client is interested in the stream or not.
+        """
+        self.send_command(RdataCommand(stream_name, token, data))
+
+
+class DummyReplicationDataHandler:
+    """A replication data handler that simply discards all data.
+    """
+
+    async def on_rdata(self, stream_name: str, token: int, rows: list):
+        """Called to handle a batch of replication data with a given stream token.
+
+        This dummy implementation simply ignores the rows.
+
+        Args:
+            stream_name (str): name of the replication stream for this batch of rows
+            token (int): stream token for this batch of rows
+            rows (list): a list of Stream.ROW_TYPE objects as returned by
+                Stream.parse_row.
+        """
+        pass
+
+    def get_streams_to_replicate(self) -> Dict[str, int]:
+        """Called when a new connection has been established and we need to
+        subscribe to streams.
+
+        Returns:
+            map from stream name to the most recent update we have for
+            that stream (ie, the point we want to start replicating from)
+        """
+        return {}
+
+    async def on_position(self, stream_name: str, token: int):
+        pass
+
+
+class WorkerReplicationDataHandler:
+    """A replication data handler that calls the slave data stores.
+    """
+
+    def __init__(self, store):
+        self.store = store
+
+    async def on_rdata(self, stream_name: str, token: int, rows: list):
+        """Called to handle a batch of replication data with a given stream token.
+
+        By default this just pokes the slave store. Can be overridden in subclasses to
+        handle more.
+
+        Args:
+            stream_name (str): name of the replication stream for this batch of rows
+            token (int): stream token for this batch of rows
+            rows (list): a list of Stream.ROW_TYPE objects as returned by
+                Stream.parse_row.
+        """
+        self.store.process_replication_rows(stream_name, token, rows)
+
+    def get_streams_to_replicate(self) -> Dict[str, int]:
+        """Called when a new connection has been established and we need to
+        subscribe to streams.
+ + Returns: + map from stream name to the most recent update we have for + that stream (ie, the point we want to start replicating from) + """ + args = self.store.stream_positions() + user_account_data = args.pop("user_account_data", None) + room_account_data = args.pop("room_account_data", None) + if user_account_data: + args["account_data"] = user_account_data + elif room_account_data: + args["account_data"] = room_account_data + return args + + async def on_position(self, stream_name: str, token: int): + self.store.process_replication_rows(stream_name, token, []) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index ff720beb56..d4456f42f3 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ -import abc import fcntl import logging import struct @@ -64,26 +63,22 @@ from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( COMMAND_MAP, - VALID_CLIENT_COMMANDS, - VALID_SERVER_COMMANDS, Command, ErrorCommand, NameCommand, PingCommand, - PositionCommand, - RdataCommand, RemoteServerUpCommand, ReplicateCommand, ServerCommand, - SyncCommand, - UserSyncCommand, ) from synapse.replication.tcp.streams import STREAMS_MAP, Stream -from synapse.server import HomeServer -from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string +MYPY = False +if MYPY: + import synapse.server + connection_close_counter = Counter( "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] ) @@ -124,16 +119,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): delimiter = b"\n" - # Valid commands we expect to receive - VALID_INBOUND_COMMANDS = [] # type: Collection[str] - - # Valid commands we can send - VALID_OUTBOUND_COMMANDS = [] # type: Collection[str] - max_line_buffer = 10000 - def __init__(self, clock): + def __init__(self, clock, handler): self.clock = clock + self.handler = handler self.last_received_command = self.clock.time_msec() self.last_sent_command = 0 @@ -173,6 +163,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # can time us out. self.send_command(PingCommand(self.clock.time_msec())) + self.handler.new_connection(self) + def send_ping(self): """Periodically sends a ping and checks if we should close the connection due to the other side timing out. @@ -210,11 +202,6 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): line = line.decode("utf-8") cmd_name, rest_of_line = line.split(" ", 1) - if cmd_name not in self.VALID_INBOUND_COMMANDS: - logger.error("[%s] invalid command %s", self.id(), cmd_name) - self.send_error("invalid command: %s", cmd_name) - return - self.last_received_command = self.clock.time_msec() self.inbound_commands_counter[cmd_name] = ( @@ -246,8 +233,23 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): Args: cmd: received command """ - handler = getattr(self, "on_%s" % (cmd.NAME,)) - await handler(cmd) + handled = False + + # First call any command handlers on this instance. These are for TCP + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + # Then call out to the handler. 
+ cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + if not handled: + logger.warning("Unhandled command: %r", cmd) def close(self): logger.warning("[%s] Closing connection", self.id()) @@ -255,6 +257,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.transport.loseConnection() self.on_connection_closed() + def send_remote_server_up(self, server: str): + self.send_command(RemoteServerUpCommand(server)) + def send_error(self, error_string, *args): """Send an error to remote and close the connection. """ @@ -376,6 +381,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.state = ConnectionStates.CLOSED self.pending_commands = [] + self.handler.lost_connection(self) + if self.transport: self.transport.unregisterProducer() @@ -399,162 +406,35 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): - VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS - VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS - - def __init__(self, server_name, clock, streamer): - BaseReplicationStreamProtocol.__init__(self, clock) # Old style class + def __init__(self, hs, server_name, clock, handler): + BaseReplicationStreamProtocol.__init__(self, clock, handler) # Old style class self.server_name = server_name - self.streamer = streamer def connectionMade(self): self.send_command(ServerCommand(self.server_name)) BaseReplicationStreamProtocol.connectionMade(self) - self.streamer.new_connection(self) async def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data - async def on_USER_SYNC(self, cmd): - await self.streamer.on_user_sync( - cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms - ) - - async def on_CLEAR_USER_SYNC(self, cmd): - await self.streamer.on_clear_user_syncs(cmd.instance_id) - - async def on_REPLICATE(self, cmd): - # Subscribe to all streams we're publishing to. - for stream_name in self.streamer.streams_by_name: - current_token = self.streamer.get_stream_token(stream_name) - self.send_command(PositionCommand(stream_name, current_token)) - - async def on_FEDERATION_ACK(self, cmd): - self.streamer.federation_ack(cmd.token) - - async def on_REMOVE_PUSHER(self, cmd): - await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) - - async def on_INVALIDATE_CACHE(self, cmd): - await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) - - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): - self.streamer.on_remote_server_up(cmd.data) - - async def on_USER_IP(self, cmd): - self.streamer.on_user_ip( - cmd.user_id, - cmd.access_token, - cmd.ip, - cmd.user_agent, - cmd.device_id, - cmd.last_seen, - ) - - def stream_update(self, stream_name, token, data): - """Called when a new update is available to stream to clients. 
- - We need to check if the client is interested in the stream or not - """ - self.send_command(RdataCommand(stream_name, token, data)) - - def send_sync(self, data): - self.send_command(SyncCommand(data)) - - def send_remote_server_up(self, server: str): - self.send_command(RemoteServerUpCommand(server)) - - def on_connection_closed(self): - BaseReplicationStreamProtocol.on_connection_closed(self) - self.streamer.lost_connection(self) - - -class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): - """ - The interface for the handler that should be passed to - ClientReplicationStreamProtocol - """ - - @abc.abstractmethod - async def on_rdata(self, stream_name, token, rows): - """Called to handle a batch of replication data with a given stream token. - - Args: - stream_name (str): name of the replication stream for this batch of rows - token (int): stream token for this batch of rows - rows (list): a list of Stream.ROW_TYPE objects as returned by - Stream.parse_row. - """ - raise NotImplementedError() - - @abc.abstractmethod - async def on_position(self, stream_name, token): - """Called when we get new position data.""" - raise NotImplementedError() - - @abc.abstractmethod - def on_sync(self, data): - """Called when get a new SYNC command.""" - raise NotImplementedError() - - @abc.abstractmethod - async def on_remote_server_up(self, server: str): - """Called when get a new REMOTE_SERVER_UP command.""" - raise NotImplementedError() - - @abc.abstractmethod - def get_streams_to_replicate(self): - """Called when a new connection has been established and we need to - subscribe to streams. - - Returns: - map from stream name to the most recent update we have for - that stream (ie, the point we want to start replicating from) - """ - raise NotImplementedError() - - @abc.abstractmethod - def get_currently_syncing_users(self): - """Get the list of currently syncing users (if any). This is called - when a connection has been established and we need to send the - currently syncing users.""" - raise NotImplementedError() - - @abc.abstractmethod - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). - """ - raise NotImplementedError() - - @abc.abstractmethod - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. 
- """ - raise NotImplementedError() - class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): - VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS - VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS - def __init__( self, - hs: HomeServer, + hs: "synapse.server.HomeServer", client_name: str, server_name: str, clock: Clock, - handler: AbstractReplicationClientHandler, + handler, ): - BaseReplicationStreamProtocol.__init__(self, clock) + BaseReplicationStreamProtocol.__init__(self, clock, handler) self.instance_id = hs.get_instance_id() self.client_name = client_name self.server_name = server_name - self.handler = handler self.streams = { stream.NAME: stream(hs) for stream in STREAMS_MAP.values() @@ -570,106 +450,16 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.pending_batches = {} # type: Dict[str, List[Any]] def connectionMade(self): - self.send_command(NameCommand(self.client_name)) BaseReplicationStreamProtocol.connectionMade(self) - # Once we've connected subscribe to the necessary streams + self.send_command(NameCommand(self.client_name)) self.replicate() - # Tell the server if we have any users currently syncing (should only - # happen on synchrotrons) - currently_syncing = self.handler.get_currently_syncing_users() - now = self.clock.time_msec() - for user_id in currently_syncing: - self.send_command(UserSyncCommand(self.instance_id, user_id, True, now)) - - # We've now finished connecting to so inform the client handler - self.handler.update_connection(self) - async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") - async def on_RDATA(self, cmd): - stream_name = cmd.stream_name - inbound_rdata_count.labels(stream_name).inc() - - try: - row = STREAMS_MAP[stream_name].parse_row(cmd.row) - except Exception: - logger.exception( - "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row - ) - raise - - if cmd.token is None or stream_name in self.streams_connecting: - # I.e. this is part of a batch of updates for this stream. Batch - # until we get an update for the stream with a non None token - self.pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self.pending_batches.pop(stream_name, []) - rows.append(row) - await self.handler.on_rdata(stream_name, cmd.token, rows) - - async def on_POSITION(self, cmd: PositionCommand): - stream = self.streams.get(cmd.stream_name) - if not stream: - logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) - return - - # Find where we previously streamed up to. - current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) - if current_token is None: - logger.warning( - "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name - ) - return - - # Fetch all updates between then and now. - limited = True - while limited: - updates, current_token, limited = await stream.get_updates_since( - current_token, cmd.token - ) - - # Check if the connection was closed underneath us, if so we bail - # rather than risk having concurrent catch ups going on. - if self.state == ConnectionStates.CLOSED: - return - - if updates: - await self.handler.on_rdata( - cmd.stream_name, - current_token, - [stream.parse_row(update[1]) for update in updates], - ) - - # We've now caught up to position sent to us, notify handler. 
- await self.handler.on_position(cmd.stream_name, cmd.token) - - # We're now up to date wit the stream - self.streams_connecting.discard(cmd.stream_name) - if not self.streams_connecting: - self.handler.finished_connecting() - - # Check if the connection was closed underneath us, if so we bail - # rather than risk having concurrent catch ups going on. - if self.state == ConnectionStates.CLOSED: - return - - # Handle any RDATA that came in while we were catching up. - rows = self.pending_batches.pop(cmd.stream_name, []) - if rows: - await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows) - - async def on_SYNC(self, cmd): - self.handler.on_sync(cmd.data) - - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): - self.handler.on_remote_server_up(cmd.data) - def replicate(self): """Send the subscription request to the server """ @@ -677,10 +467,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_command(ReplicateCommand()) - def on_connection_closed(self): - BaseReplicationStreamProtocol.on_connection_closed(self) - self.handler.update_connection(None) - # The following simply registers metrics for the replication connections diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index acf8868de9..28edbbdac9 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -17,7 +17,6 @@ import logging import random -from typing import Any, List from six import itervalues @@ -25,9 +24,8 @@ from prometheus_client import Counter from twisted.internet.protocol import Factory -from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.metrics import Measure, measure_func +from synapse.util.metrics import Measure from .protocol import ServerReplicationStreamProtocol from .streams import STREAMS_MAP @@ -36,13 +34,6 @@ from .streams.federation import FederationStream stream_updates_counter = Counter( "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] ) -user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") -federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") -remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter( - "synapse_replication_tcp_resource_invalidate_cache", "" -) -user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") logger = logging.getLogger(__name__) @@ -52,13 +43,18 @@ class ReplicationStreamProtocolFactory(Factory): """ def __init__(self, hs): - self.streamer = ReplicationStreamer(hs) + self.handler = hs.get_tcp_replication() self.clock = hs.get_clock() self.server_name = hs.config.server_name + self.hs = hs + + # Ensure the replication streamer is started if we register a + # replication server endpoint. + hs.get_replication_streamer() def buildProtocol(self, addr): return ServerReplicationStreamProtocol( - self.server_name, self.clock, self.streamer + self.hs, self.server_name, self.clock, self.handler ) @@ -78,16 +74,6 @@ class ReplicationStreamer(object): self._replication_torture_level = hs.config.replication_torture_level - # Current connections. - self.connections = [] # type: List[ServerReplicationStreamProtocol] - - LaterGauge( - "synapse_replication_tcp_resource_total_connections", - "", - [], - lambda: len(self.connections), - ) - # List of streams that clients can subscribe to. 
# We only support federation stream if federation sending hase been # disabled on the master. @@ -99,39 +85,17 @@ class ReplicationStreamer(object): self.streams_by_name = {stream.NAME: stream for stream in self.streams} - LaterGauge( - "synapse_replication_tcp_resource_connections_per_stream", - "", - ["stream_name"], - lambda: { - (stream_name,): len( - [ - conn - for conn in self.connections - if stream_name in conn.replication_streams - ] - ) - for stream_name in self.streams_by_name - }, - ) - self.federation_sender = None if not hs.config.send_federation: self.federation_sender = hs.get_federation_sender() self.notifier.add_replication_callback(self.on_notifier_poke) - self.notifier.add_remote_server_up_callback(self.send_remote_server_up) # Keeps track of whether we are currently checking for updates self.is_looping = False self.pending_updates = False - hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown) - - def on_shutdown(self): - # close all connections on shutdown - for conn in self.connections: - conn.send_error("server shutting down") + self.client = hs.get_tcp_replication() def on_notifier_poke(self): """Checks if there is actually any new data and sends it to the @@ -140,7 +104,7 @@ class ReplicationStreamer(object): This should get called each time new data is available, even if it is currently being executed, so that nothing gets missed """ - if not self.connections: + if not self.client.connected(): # Don't bother if nothing is listening. We still need to advance # the stream tokens otherwise they'll fall beihind forever for stream in self.streams: @@ -197,9 +161,7 @@ class ReplicationStreamer(object): raise logger.debug( - "Sending %d updates to %d connections", - len(updates), - len(self.connections), + "Sending %d updates", len(updates), ) if updates: @@ -215,112 +177,17 @@ class ReplicationStreamer(object): # token. See RdataCommand for more details. batched_updates = _batch_updates(updates) - for conn in self.connections: - for token, row in batched_updates: - try: - conn.stream_update(stream.NAME, token, row) - except Exception: - logger.exception("Failed to replicate") + for token, row in batched_updates: + try: + self.client.stream_update(stream.NAME, token, row) + except Exception: + logger.exception("Failed to replicate") logger.debug("No more pending updates, breaking poke loop") finally: self.pending_updates = False self.is_looping = False - def get_stream_token(self, stream_name): - """For a given stream get all updates since token. This is called when - a client first subscribes to a stream. - """ - stream = self.streams_by_name.get(stream_name, None) - if not stream: - raise Exception("unknown stream %s", stream_name) - - return stream.current_token() - - @measure_func("repl.federation_ack") - def federation_ack(self, token): - """We've received an ack for federation stream from a client. - """ - federation_ack_counter.inc() - if self.federation_sender: - self.federation_sender.federation_ack(token) - - @measure_func("repl.on_user_sync") - async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): - """A client has started/stopped syncing on a worker. - """ - user_sync_counter.inc() - await self.presence_handler.update_external_syncs_row( - instance_id, user_id, is_syncing, last_sync_ms - ) - - async def on_clear_user_syncs(self, instance_id): - """A replication client wants us to drop all their UserSync data. 
- """ - await self.presence_handler.update_external_syncs_clear(instance_id) - - @measure_func("repl.on_remove_pusher") - async def on_remove_pusher(self, app_id, push_key, user_id): - """A client has asked us to remove a pusher - """ - remove_pusher_counter.inc() - await self.store.delete_pusher_by_app_id_pushkey_user_id( - app_id=app_id, pushkey=push_key, user_id=user_id - ) - - self.notifier.on_new_replication_data() - - @measure_func("repl.on_invalidate_cache") - async def on_invalidate_cache(self, cache_func: str, keys: List[Any]): - """The client has asked us to invalidate a cache - """ - invalidate_cache_counter.inc() - - # We invalidate the cache locally, but then also stream that to other - # workers. - await self.store.invalidate_cache_and_stream(cache_func, tuple(keys)) - - @measure_func("repl.on_user_ip") - async def on_user_ip( - self, user_id, access_token, ip, user_agent, device_id, last_seen - ): - """The client saw a user request - """ - user_ip_cache_counter.inc() - await self.store.insert_client_ip( - user_id, access_token, ip, user_agent, device_id, last_seen - ) - await self._server_notices_sender.on_user_ip(user_id) - - @measure_func("repl.on_remote_server_up") - def on_remote_server_up(self, server: str): - self.notifier.notify_remote_server_up(server) - - def send_remote_server_up(self, server: str): - for conn in self.connections: - conn.send_remote_server_up(server) - - def send_sync_to_all_connections(self, data): - """Sends a SYNC command to all clients. - - Used in tests. - """ - for conn in self.connections: - conn.send_sync(data) - - def new_connection(self, connection): - """A new client connection has been established - """ - self.connections.append(connection) - - def lost_connection(self, connection): - """A client connection has been lost - """ - try: - self.connections.remove(connection) - except ValueError: - pass - def _batch_updates(updates): """Takes a list of updates of form [(token, row)] and sets the token to diff --git a/synapse/server.py b/synapse/server.py index 392f9e8e38..5f5d79161c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -85,6 +85,11 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool +from synapse.replication.tcp.handler import ( + DummyReplicationDataHandler, + ReplicationClientHandler, +) +from synapse.replication.tcp.resource import ReplicationStreamer from synapse.rest.media.v1.media_repository import ( MediaRepository, MediaRepositoryResource, @@ -200,6 +205,8 @@ class HomeServer(object): "saml_handler", "event_client_serializer", "storage", + "replication_streamer", + "replication_data_handler", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -459,7 +466,7 @@ class HomeServer(object): return ReadMarkerHandler(self) def build_tcp_replication(self): - raise NotImplementedError() + return ReplicationClientHandler(self) def build_action_generator(self): return ActionGenerator(self) @@ -544,6 +551,12 @@ class HomeServer(object): def build_storage(self) -> Storage: return Storage(self, self.datastores) + def build_replication_streamer(self) -> ReplicationStreamer: + return ReplicationStreamer(self) + + def build_replication_data_handler(self): + return DummyReplicationDataHandler() + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git 
a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 2a1e7c7166..f2c0e381c1 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -15,9 +15,10 @@ from mock import Mock, NonCallableMock -from synapse.replication.tcp.client import ( - ReplicationClientFactory, +from synapse.replication.tcp.client import ReplicationClientFactory +from synapse.replication.tcp.handler import ( ReplicationClientHandler, + WorkerReplicationDataHandler, ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.storage.database import make_conn @@ -51,16 +52,19 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): self.event_id = 0 server_factory = ReplicationStreamProtocolFactory(self.hs) - self.streamer = server_factory.streamer + self.streamer = hs.get_replication_streamer() - handler_factory = Mock() - self.replication_handler = ReplicationClientHandler(self.slaved_store) - self.replication_handler.factory = handler_factory - - client_factory = ReplicationClientFactory( - self.hs, "client_name", self.replication_handler + # We now do some gut wrenching so that we have a client that is based + # off of the slave store rather than the main store. + self.replication_handler = ReplicationClientHandler(self.hs) + self.replication_handler.store = self.slaved_store + self.replication_handler.replication_data_handler = WorkerReplicationDataHandler( + self.slaved_store ) + client_factory = ReplicationClientFactory(self.hs, "client_name") + client_factory.handler = self.replication_handler + server = server_factory.buildProtocol(None) client = client_factory.buildProtocol(None) diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index a755fe2879..2d6e44f625 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -15,7 +15,7 @@ from mock import Mock -from synapse.replication.tcp.commands import ReplicateCommand +from synapse.replication.tcp.handler import ReplicationClientHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -26,15 +26,20 @@ from tests.server import FakeTransport class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" + def make_homeserver(self, reactor, clock): + self.test_handler = Mock(wraps=TestReplicationClientHandler()) + return self.setup_test_homeserver(replication_data_handler=self.test_handler) + def prepare(self, reactor, clock, hs): # build a replication server - server_factory = ReplicationStreamProtocolFactory(self.hs) - self.streamer = server_factory.streamer + server_factory = ReplicationStreamProtocolFactory(hs) + self.streamer = hs.get_replication_streamer() self.server = server_factory.buildProtocol(None) - self.test_handler = Mock(wraps=TestReplicationClientHandler()) + repl_handler = ReplicationClientHandler(hs) + repl_handler.handler = self.test_handler self.client = ClientReplicationStreamProtocol( - hs, "client", "test", clock, self.test_handler, + hs, "client", "test", clock, repl_handler, ) self._client_transport = None @@ -69,14 +74,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.streamer.on_notifier_poke() self.pump(0.1) - def replicate_stream(self): - """Make the client end a REPLICATE command to set up a subscription to a stream""" - self.client.send_command(ReplicateCommand()) - - 
-class TestReplicationClientHandler(object): - """Drop-in for ReplicationClientHandler which just collects RDATA rows""" +class TestReplicationClientHandler: def __init__(self): self.streams = set() self._received_rdata_rows = [] @@ -88,18 +87,9 @@ class TestReplicationClientHandler(object): positions[stream] = max(token, positions.get(stream, 0)) return positions - def get_currently_syncing_users(self): - return [] - - def update_connection(self, connection): - pass - - def finished_connecting(self): - pass - - async def on_position(self, stream_name, token): - """Called when we get new position data.""" - async def on_rdata(self, stream_name, token, rows): for r in rows: self._received_rdata_rows.append((stream_name, token, r)) + + async def on_position(self, stream_name, token): + pass diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index 0ec0825a0e..a0206f7363 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -24,7 +24,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.reconnect() # make the client subscribe to the receipts stream - self.replicate_stream() self.test_handler.streams.add("receipts") # tell the master to send a new receipt
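
Notes on the mechanisms this patch introduces follow; the snippets are minimal runnable sketches using stand-in classes, not Synapse's real ones.

The core change is that `BaseReplicationStreamProtocol.handle_command` now dispatches each command twice: first to an `on_<NAME>` method on the protocol itself (for transport-level concerns), then to the shared handler obtained from `hs.get_tcp_replication()`, warning only if neither handled it. A sketch of that two-stage dispatch:

    # Two-stage "on_<NAME>" dispatch, as in the new handle_command.
    # PingCommand/Handler/Protocol are simplified stand-ins.
    import asyncio
    import logging

    logger = logging.getLogger(__name__)


    class PingCommand:
        NAME = "PING"

        def __init__(self, timestamp):
            self.timestamp = timestamp


    class Handler:
        async def on_PING(self, cmd):
            print("handler saw PING at", cmd.timestamp)


    class Protocol:
        def __init__(self, handler):
            self.handler = handler

        async def handle_command(self, cmd):
            handled = False

            # Transport-specific handling on the protocol itself, if any.
            cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
            if cmd_func:
                await cmd_func(cmd)
                handled = True

            # Shared handling on the command handler.
            cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
            if cmd_func:
                await cmd_func(cmd)
                handled = True

            if not handled:
                logger.warning("Unhandled command: %r", cmd)

        async def on_PING(self, cmd):
            print("protocol saw PING at", cmd.timestamp)


    asyncio.run(Protocol(Handler()).handle_command(PingCommand(1234)))

Both `on_PING` methods run here, which is the point: the protocol can do keepalive bookkeeping while the handler does the real work.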
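The `on_RDATA` logic moved into `handler.py` relies on the RDATA batching convention: a multi-row update is sent as several RDATA commands whose token is None, terminated by one carrying the real token. A sketch of the buffering, with a stand-in RdataCommand:

    # RDATA batching: buffer rows with token=None, flush on a real token.
    import asyncio
    from collections import namedtuple

    RdataCommand = namedtuple("RdataCommand", ["stream_name", "token", "row"])


    class BatchingHandler:
        def __init__(self):
            self.pending_batches = {}  # stream name -> buffered rows

        async def on_RDATA(self, cmd):
            if cmd.token is None:
                # Part of a batch: hold on to the row for now.
                self.pending_batches.setdefault(cmd.stream_name, []).append(cmd.row)
            else:
                # Last of the batch: flush everything at this token.
                rows = self.pending_batches.pop(cmd.stream_name, [])
                rows.append(cmd.row)
                await self.on_rdata(cmd.stream_name, cmd.token, rows)

        async def on_rdata(self, stream_name, token, rows):
            print("flushed %d rows for %s at token %s" % (len(rows), stream_name, token))


    async def main():
        handler = BatchingHandler()
        await handler.on_RDATA(RdataCommand("events", None, "row1"))
        await handler.on_RDATA(RdataCommand("events", None, "row2"))
        await handler.on_RDATA(RdataCommand("events", 42, "row3"))  # flushes all three


    asyncio.run(main())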
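The new `on_POSITION` catches up by paging through `Stream.get_updates_since(current_token, upto_token)` until the stream reports it is no longer limited. A sketch of that loop against a fake in-memory stream (FakeStream is an illustrative stand-in, not Synapse's Stream class):

    # Catch-up loop as in on_POSITION: page through updates until not limited.
    import asyncio


    class FakeStream:
        def __init__(self, updates, page_size=2):
            self._updates = updates  # list of (token, row), in token order
            self._page_size = page_size

        async def get_updates_since(self, current_token, upto_token):
            page = [
                (t, row)
                for t, row in self._updates
                if current_token < t <= upto_token
            ][: self._page_size]
            if not page:
                return [], upto_token, False
            new_token = page[-1][0]
            # "limited" means there may be more updates before upto_token.
            return page, new_token, new_token < upto_token


    async def catch_up(stream, current_token, position_token):
        limited = position_token != current_token
        while limited:
            updates, current_token, limited = await stream.get_updates_since(
                current_token, position_token
            )
            if updates:
                print("got rows", [row for _, row in updates], "up to", current_token)


    stream = FakeStream([(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e")])
    asyncio.run(catch_up(stream, current_token=1, position_token=5))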
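The `server.py` changes lean on HomeServer's lazy `build_*`/`get_*` pattern: adding `replication_streamer` and `replication_data_handler` to the dependency list means each is built once on first access, and `GenericWorkerServer` overrides `build_replication_data_handler` to swap the dummy handler for the worker one. A simplified sketch of the pattern (not the real HomeServer machinery):

    # Lazy build_*/get_* singletons, with a subclass overriding the builder.
    class DummyReplicationDataHandler:
        """Master default: discards replication data."""


    class GenericWorkerReplicationHandler:
        """Worker override: feeds data to the slave stores."""


    class HomeServer:
        def __init__(self):
            self._built = {}

        def _get_or_build(self, name):
            # Build on first access, then cache for the server's lifetime.
            if name not in self._built:
                self._built[name] = getattr(self, "build_%s" % name)()
            return self._built[name]

        def get_replication_data_handler(self):
            return self._get_or_build("replication_data_handler")

        def build_replication_data_handler(self):
            return DummyReplicationDataHandler()


    class GenericWorkerServer(HomeServer):
        def build_replication_data_handler(self):
            return GenericWorkerReplicationHandler()


    assert isinstance(
        GenericWorkerServer().get_replication_data_handler(),
        GenericWorkerReplicationHandler,
    )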
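Finally, `generic_worker.py` now constructs the `ReplicationClientFactory` itself and hands it straight to `reactor.connectTCP`, relying on Twisted's `ReconnectingClientFactory` for exponential backoff between `initialDelay` and `maxDelay`. A minimal sketch of that wiring; the host, port, and wire bytes are illustrative, not Synapse's protocol:

    # Reconnecting client setup, assuming something is listening on the port.
    from twisted.internet import reactor
    from twisted.internet.protocol import Protocol, ReconnectingClientFactory


    class NameAnnouncingProtocol(Protocol):
        def connectionMade(self):
            # Reset the reconnect backoff once we successfully connect.
            self.factory.resetDelay()
            self.transport.write(b"NAME worker1\n")


    class ReplicationLikeFactory(ReconnectingClientFactory):
        initialDelay = 0.1  # first retry after 100ms
        maxDelay = 1  # then try at least once a second
        protocol = NameAnnouncingProtocol


    reactor.connectTCP("localhost", 9092, ReplicationLikeFactory())
    reactor.callLater(5, reactor.stop)  # stop the demo after a few seconds
    reactor.run()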