Add support for moving /push_rules off of main process (#17037)

Erik Johnston 2024-03-28 15:44:07 +00:00 committed by GitHub
parent 59ceabcb97
commit ea6bfae0fc
10 changed files with 133 additions and 40 deletions

changelog.d/17037.feature

@@ -0,0 +1 @@
+Add support for moving `/push_rules` off of main process.

docs/workers.md

@@ -532,6 +532,13 @@ the stream writer for the `presence` stream:

     ^/_matrix/client/(api/v1|r0|v3|unstable)/presence/

+##### The `push` stream
+
+The following endpoints should be routed directly to the worker configured as
+the stream writer for the `push` stream:
+
+    ^/_matrix/client/(api/v1|r0|v3|unstable)/push_rules/
+
 #### Restrict outbound federation traffic to a specific set of workers

 The
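For illustration, the corresponding homeserver configuration is a single `stream_writers` entry; a minimal sketch, assuming a hypothetical worker named `pusher1`:

    # homeserver.yaml
    stream_writers:
      push: pusher1

The reverse proxy then routes `^/_matrix/client/(api/v1|r0|v3|unstable)/push_rules/` to `pusher1`, per the routing rule documented above.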

synapse/_scripts/synapse_port_db.py

@@ -60,7 +60,7 @@ from synapse.logging.context import (
 )
 from synapse.notifier import ReplicationNotifier
 from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
-from synapse.storage.databases.main import FilteringWorkerStore, PushRuleStore
+from synapse.storage.databases.main import FilteringWorkerStore
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
 from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
@@ -77,10 +77,8 @@ from synapse.storage.databases.main.media_repository import (
 )
 from synapse.storage.databases.main.presence import PresenceBackgroundUpdateStore
 from synapse.storage.databases.main.profile import ProfileWorkerStore
-from synapse.storage.databases.main.pusher import (
-    PusherBackgroundUpdatesStore,
-    PusherWorkerStore,
-)
+from synapse.storage.databases.main.push_rule import PusherWorkerStore
+from synapse.storage.databases.main.pusher import PusherBackgroundUpdatesStore
 from synapse.storage.databases.main.receipts import ReceiptsBackgroundUpdateStore
 from synapse.storage.databases.main.registration import (
     RegistrationBackgroundUpdateStore,
@@ -245,7 +243,6 @@ class Store(
     AccountDataWorkerStore,
     FilteringWorkerStore,
     ProfileWorkerStore,
-    PushRuleStore,
     PusherWorkerStore,
     PusherBackgroundUpdatesStore,
     PresenceBackgroundUpdateStore,

synapse/config/workers.py

@@ -156,6 +156,8 @@ class WriterLocations:
             can only be a single instance.
         presence: The instances that write to the presence stream. Currently
             can only be a single instance.
+        push: The instances that write to the push stream. Currently
+            can only be a single instance.
     """

     events: List[str] = attr.ib(
@@ -182,6 +184,10 @@
         default=["master"],
         converter=_instance_to_list_converter,
     )
+    push: List[str] = attr.ib(
+        default=["master"],
+        converter=_instance_to_list_converter,
+    )


 @attr.s(auto_attribs=True)
@@ -341,6 +347,7 @@ class WorkerConfig(Config):
             "account_data",
             "receipts",
             "presence",
+            "push",
         ):
             instances = _instance_to_list_converter(getattr(self.writers, stream))
             for instance in instances:
@@ -378,6 +385,11 @@
                 "Must only specify one instance to handle `presence` messages."
             )

+        if len(self.writers.push) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `push` messages."
+            )
+
         self.events_shard_config = RoutableShardedWorkerHandlingConfig(
             self.writers.events
         )
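As the check above enforces, the `push` stream accepts exactly one writer. A sketch of a configuration this would reject (worker names hypothetical):

    stream_writers:
      push:
        - pusher1
        - pusher2   # rejected: "Must only specify one instance to handle `push` messages."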

synapse/handlers/room_member.py

@@ -51,6 +51,7 @@ from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
 from synapse.logging import opentracing
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.http.push import ReplicationCopyPusherRestServlet
 from synapse.storage.databases.main.state_deltas import StateDelta
 from synapse.types import (
     JsonDict,
@@ -181,6 +182,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             hs.config.server.forgotten_room_retention_period
         )

+        self._is_push_writer = hs.get_instance_name() in hs.config.worker.writers.push
+        self._push_writer = hs.config.worker.writers.push[0]
+        self._copy_push_client = ReplicationCopyPusherRestServlet.make_client(hs)
+
     def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
         """Notify the rate limiter that a room join has occurred.
@@ -1301,9 +1306,17 @@
                 old_room_id, new_room_id, user_id
             )
             # Copy over push rules
-            await self.store.copy_push_rules_from_room_to_room_for_user(
-                old_room_id, new_room_id, user_id
-            )
+            if self._is_push_writer:
+                await self.store.copy_push_rules_from_room_to_room_for_user(
+                    old_room_id, new_room_id, user_id
+                )
+            else:
+                await self._copy_push_client(
+                    instance_name=self._push_writer,
+                    user_id=user_id,
+                    old_room_id=old_room_id,
+                    new_room_id=new_room_id,
+                )
         except Exception:
             logger.exception(
                 "Error copying tags and/or push rules from rooms %s to %s for user %s. "

synapse/replication/http/push.py

@@ -77,5 +77,46 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
         return 200, {}


+class ReplicationCopyPusherRestServlet(ReplicationEndpoint):
+    """Copies push rules from an old room to new room.
+
+    Request format:
+
+        POST /_synapse/replication/copy_push_rules/:user_id/:old_room_id/:new_room_id
+
+        {}
+    """
+
+    NAME = "copy_push_rules"
+    PATH_ARGS = ("user_id", "old_room_id", "new_room_id")
+    CACHE = False
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self._store = hs.get_datastores().main
+
+    @staticmethod
+    async def _serialize_payload(user_id: str, old_room_id: str, new_room_id: str) -> JsonDict:  # type: ignore[override]
+        return {}
+
+    async def _handle_request(  # type: ignore[override]
+        self,
+        request: Request,
+        content: JsonDict,
+        user_id: str,
+        old_room_id: str,
+        new_room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        await self._store.copy_push_rules_from_room_to_room_for_user(
+            old_room_id, new_room_id, user_id
+        )
+
+        return 200, {}
+
+
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ReplicationRemovePusherRestServlet(hs).register(http_server)
+    ReplicationCopyPusherRestServlet(hs).register(http_server)
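For context, this is how a non-writer instance invokes the endpoint (IDs hypothetical; `make_client` builds a callable that issues the replication request, exactly as used in the handler diff above):

    copy_push_client = ReplicationCopyPusherRestServlet.make_client(hs)
    await copy_push_client(
        instance_name=hs.config.worker.writers.push[0],  # the single push writer
        user_id="@alice:example.com",
        old_room_id="!old:example.com",
        new_room_id="!new:example.com",
    )
    # Sends POST /_synapse/replication/copy_push_rules/{user_id}/{old_room_id}/{new_room_id}
    # to the writer, which performs the database copy and returns 200, {}.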

synapse/replication/tcp/handler.py

@@ -66,6 +66,7 @@ from synapse.replication.tcp.streams import (
     FederationStream,
     PresenceFederationStream,
     PresenceStream,
+    PushRulesStream,
     ReceiptsStream,
     Stream,
     ToDeviceStream,
@@ -178,6 +179,12 @@ class ReplicationCommandHandler:
                 continue

+            if isinstance(stream, PushRulesStream):
+                if hs.get_instance_name() in hs.config.worker.writers.push:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             # Only add any other streams if we're on master.
             if hs.config.worker.worker_app is not None:
                 continue
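The effect is that only the configured push writer advertises positions for `PushRulesStream`, while every other instance merely consumes it. A condensed sketch of the branch added above, with the reasoning as comments:

    if isinstance(stream, PushRulesStream):
        if hs.get_instance_name() in hs.config.worker.writers.push:
            # This instance writes push rules, so it must announce its
            # position on the stream for other workers to follow.
            self._streams_to_replicate.append(stream)
        # Writer or not, the stream is now handled; don't fall through to
        # the master-only default below.
        continue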

synapse/rest/client/push_rule.py

@@ -59,12 +59,12 @@ class PushRuleRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker.worker_app is not None
+        self._is_push_worker = hs.get_instance_name() in hs.config.worker.writers.push
         self._push_rules_handler = hs.get_push_rules_handler()
         self._push_rule_linearizer = Linearizer(name="push_rules")

     async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
-        if self._is_worker:
+        if not self._is_push_worker:
             raise Exception("Cannot handle PUT /push_rules on worker")

         requester = await self.auth.get_user_by_req(request)
@@ -137,7 +137,7 @@ class PushRuleRestServlet(RestServlet):
     async def on_DELETE(
         self, request: SynapseRequest, path: str
     ) -> Tuple[int, JsonDict]:
-        if self._is_worker:
+        if not self._is_push_worker:
             raise Exception("Cannot handle DELETE /push_rules on worker")

         requester = await self.auth.get_user_by_req(request)

synapse/storage/databases/main/__init__.py

@@ -63,7 +63,7 @@ from .openid import OpenIdStore
 from .presence import PresenceStore
 from .profile import ProfileStore
 from .purge_events import PurgeEventsStore
-from .push_rule import PushRuleStore
+from .push_rule import PushRulesWorkerStore
 from .pusher import PusherStore
 from .receipts import ReceiptsStore
 from .registration import RegistrationStore
@@ -130,7 +130,6 @@ class DataStore(
     RejectionsStore,
     FilteringWorkerStore,
     PusherStore,
-    PushRuleStore,
     ApplicationServiceTransactionStore,
     EventPushActionsStore,
     ServerMetricsStore,
@@ -140,6 +139,7 @@ class DataStore(
     SearchStore,
     TagsStore,
     AccountDataStore,
+    PushRulesWorkerStore,
     StreamWorkerStore,
     OpenIdStore,
     ClientIpWorkerStore,

synapse/storage/databases/main/push_rule.py

@@ -53,11 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    IdGenerator,
-    StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator
 from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
 from synapse.types import JsonDict
 from synapse.util import json_encoder, unwrapFirstError
@@ -130,6 +126,8 @@ class PushRulesWorkerStore(
     `get_max_push_rules_stream_id` which can be called in the initializer.
     """

+    _push_rules_stream_id_gen: StreamIdGenerator
+
     def __init__(
         self,
         database: DatabasePool,
@@ -138,6 +136,8 @@
     ):
         super().__init__(database, db_conn, hs)

+        self._is_push_writer = hs.get_instance_name() in hs.config.worker.writers.push
+
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
         self._push_rules_stream_id_gen = StreamIdGenerator(
@@ -145,7 +145,7 @@
             hs.get_replication_notifier(),
             "push_rules_stream",
             "stream_id",
-            is_writer=hs.config.worker.worker_app is None,
+            is_writer=self._is_push_writer,
         )

         push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
@@ -162,6 +162,9 @@
             prefilled_cache=push_rules_prefill,
         )

+        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+
     def get_max_push_rules_stream_id(self) -> int:
         """Get the position of the push rules stream.
@@ -383,23 +386,6 @@
             "get_all_push_rule_updates", get_all_push_rule_updates_txn
         )

-
-class PushRuleStore(PushRulesWorkerStore):
-    # Because we have write access, this will be a StreamIdGenerator
-    # (see PushRulesWorkerStore.__init__)
-    _push_rules_stream_id_gen: AbstractStreamIdGenerator
-
-    def __init__(
-        self,
-        database: DatabasePool,
-        db_conn: LoggingDatabaseConnection,
-        hs: "HomeServer",
-    ):
-        super().__init__(database, db_conn, hs)
-
-        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
-        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
-
     async def add_push_rule(
         self,
         user_id: str,
@@ -410,6 +396,9 @@ class PushRuleStore(PushRulesWorkerStore):
         before: Optional[str] = None,
         after: Optional[str] = None,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)

         async with self._push_rules_stream_id_gen.get_next() as stream_id:
@@ -455,6 +444,9 @@ class PushRuleStore(PushRulesWorkerStore):
         before: str,
         after: str,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         relative_to_rule = before or after

         sql = """
@@ -524,6 +516,9 @@ class PushRuleStore(PushRulesWorkerStore):
         conditions_json: str,
         actions_json: str,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         if isinstance(self.database_engine, PostgresEngine):
             # Postgres doesn't do FOR UPDATE on aggregate functions, so select the rows first
             # then re-select the count/max below.
@@ -575,6 +570,9 @@ class PushRuleStore(PushRulesWorkerStore):
         actions_json: str,
         update_stream: bool = True,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         """Specialised version of simple_upsert_txn that picks a push_rule_id
         using the _push_rule_id_gen if it needs to insert the rule. It assumes
         that the "push_rules" table is locked"""
@@ -653,6 +651,8 @@ class PushRuleStore(PushRulesWorkerStore):
             user_id: The matrix ID of the push rule owner
             rule_id: The rule_id of the rule to be deleted
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")

         def delete_push_rule_txn(
             txn: LoggingTransaction,
@@ -704,6 +704,9 @@ class PushRuleStore(PushRulesWorkerStore):
         Raises:
             RuleNotFoundException if the rule does not exist.
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()

             await self.db_pool.runInteraction(
@@ -727,6 +730,9 @@ class PushRuleStore(PushRulesWorkerStore):
         enabled: bool,
         is_default_rule: bool,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         new_id = self._push_rules_enable_id_gen.get_next()

         if not is_default_rule:
@@ -796,6 +802,9 @@ class PushRuleStore(PushRulesWorkerStore):
         Raises:
             RuleNotFoundException if the rule does not exist.
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         actions_json = json_encoder.encode(actions)

         def set_push_rule_actions_txn(
@@ -865,6 +874,9 @@ class PushRuleStore(PushRulesWorkerStore):
         op: str,
         data: Optional[JsonDict] = None,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         values = {
             "stream_id": stream_id,
             "event_stream_ordering": event_stream_ordering,
@@ -882,9 +894,6 @@ class PushRuleStore(PushRulesWorkerStore):
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )

-    def get_max_push_rules_stream_id(self) -> int:
-        return self._push_rules_stream_id_gen.get_current_token()
-
     async def copy_push_rule_from_room_to_room(
         self, new_room_id: str, user_id: str, rule: PushRule
     ) -> None:
@@ -895,6 +904,9 @@ class PushRuleStore(PushRulesWorkerStore):
             user_id : ID of user the push rule belongs to.
             rule: A push rule.
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         # Create new rule id
         rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
         new_rule_id = rule_id_scope + "/" + new_room_id
@@ -930,6 +942,9 @@ class PushRuleStore(PushRulesWorkerStore):
             new_room_id: ID of the new room.
             user_id: ID of user to copy push rules for.
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         # Retrieve push rules for this user
         user_push_rules = await self.get_push_rules_for_user(user_id)
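With the `PushRuleStore` subclass gone, every instance loads the same `PushRulesWorkerStore`, and the write methods guard themselves at runtime instead of relying on type-level separation. A sketch of the observable behaviour (IDs hypothetical; keyword names taken from the docstrings above):

    # On the configured push writer this performs the write:
    await store.copy_push_rules_from_room_to_room_for_user(
        old_room_id="!old:example.com",
        new_room_id="!new:example.com",
        user_id="@alice:example.com",
    )

    # On any other instance the same call raises immediately:
    #   Exception("Not a push writer")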