diff --git a/changelog.d/17164.bugfix b/changelog.d/17164.bugfix new file mode 100644 index 0000000000..597e2f14b0 --- /dev/null +++ b/changelog.d/17164.bugfix @@ -0,0 +1 @@ +Fix deduplicating of membership events to not create unused state groups. diff --git a/changelog.d/17229.misc b/changelog.d/17229.misc new file mode 100644 index 0000000000..d411550786 --- /dev/null +++ b/changelog.d/17229.misc @@ -0,0 +1 @@ +Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator`. diff --git a/changelog.d/17242.misc b/changelog.d/17242.misc new file mode 100644 index 0000000000..5bd627da57 --- /dev/null +++ b/changelog.d/17242.misc @@ -0,0 +1 @@ +Clean out invalid destinations from `device_federation_outbox` table. diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 1e56f46911..3bb4a34938 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -777,22 +777,74 @@ class Porter: await self._setup_events_stream_seqs() await self._setup_sequence( "un_partial_stated_event_stream_sequence", - ("un_partial_stated_event_stream",), + [("un_partial_stated_event_stream", "stream_id")], ) await self._setup_sequence( - "device_inbox_sequence", ("device_inbox", "device_federation_outbox") + "device_inbox_sequence", + [ + ("device_inbox", "stream_id"), + ("device_federation_outbox", "stream_id"), + ], ) await self._setup_sequence( "account_data_sequence", - ("room_account_data", "room_tags_revisions", "account_data"), + [ + ("room_account_data", "stream_id"), + ("room_tags_revisions", "stream_id"), + ("account_data", "stream_id"), + ], + ) + await self._setup_sequence( + "receipts_sequence", + [ + ("receipts_linearized", "stream_id"), + ], + ) + await self._setup_sequence( + "presence_stream_sequence", + [ + ("presence_stream", "stream_id"), + ], ) - await self._setup_sequence("receipts_sequence", ("receipts_linearized",)) - await self._setup_sequence("presence_stream_sequence", ("presence_stream",)) await self._setup_auth_chain_sequence() await self._setup_sequence( "application_services_txn_id_seq", - ("application_services_txns",), - "txn_id", + [ + ( + "application_services_txns", + "txn_id", + ) + ], + ) + await self._setup_sequence( + "device_lists_sequence", + [ + ("device_lists_stream", "stream_id"), + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), + ("device_lists_remote_pending", "stream_id"), + ("device_lists_changes_converted_stream_position", "stream_id"), + ], + ) + await self._setup_sequence( + "e2e_cross_signing_keys_sequence", + [ + ("e2e_cross_signing_keys", "stream_id"), + ], + ) + await self._setup_sequence( + "push_rules_stream_sequence", + [ + ("push_rules_stream", "stream_id"), + ], + ) + await self._setup_sequence( + "pushers_sequence", + [ + ("pushers", "id"), + ("deleted_pushers", "stream_id"), + ], ) # Step 3. Get tables. @@ -1101,12 +1153,11 @@ class Porter: async def _setup_sequence( self, sequence_name: str, - stream_id_tables: Iterable[str], - column_name: str = "stream_id", + stream_id_tables: Iterable[Tuple[str, str]], ) -> None: """Set a sequence to the correct value.""" current_stream_ids = [] - for stream_id_table in stream_id_tables: + for stream_id_table, column_name in stream_id_tables: max_stream_id = cast( int, await self.sqlite_store.db_pool.simple_select_one_onecol( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ccaa5508ff..de5bd44a5f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -496,13 +496,6 @@ class EventCreationHandler: self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state - self.membership_types_to_include_profile_data_in = { - Membership.JOIN, - Membership.KNOCK, - } - if self.hs.config.server.include_profile_data_on_invite: - self.membership_types_to_include_profile_data_in.add(Membership.INVITE) - self.send_event = ReplicationSendEventRestServlet.make_client(hs) self.send_events = ReplicationSendEventsRestServlet.make_client(hs) @@ -594,8 +587,6 @@ class EventCreationHandler: Creates an FrozenEvent object, filling out auth_events, prev_events, etc. - Adds display names to Join membership events. - Args: requester event_dict: An entire event @@ -672,29 +663,6 @@ class EventCreationHandler: self.validator.validate_builder(builder) - if builder.type == EventTypes.Member: - membership = builder.content.get("membership", None) - target = UserID.from_string(builder.state_key) - - if membership in self.membership_types_to_include_profile_data_in: - # If event doesn't include a display name, add one. - profile = self.profile_handler - content = builder.content - - try: - if "displayname" not in content: - displayname = await profile.get_displayname(target) - if displayname is not None: - content["displayname"] = displayname - if "avatar_url" not in content: - avatar_url = await profile.get_avatar_url(target) - if avatar_url is not None: - content["avatar_url"] = avatar_url - except Exception as e: - logger.info( - "Failed to get profile information for %r: %s", target, e - ) - is_exempt = await self._is_exempt_from_privacy_policy(builder, requester) if require_consent and not is_exempt: await self.assert_accepted_privacy_policy(requester) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index b35dd84e6a..7311caf60f 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -106,6 +106,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.event_auth_handler = hs.get_event_auth_handler() self._worker_lock_handler = hs.get_worker_locks_handler() + self._membership_types_to_include_profile_data_in = { + Membership.JOIN, + Membership.KNOCK, + } + if self.hs.config.server.include_profile_data_on_invite: + self._membership_types_to_include_profile_data_in.add(Membership.INVITE) + self.member_linearizer: Linearizer = Linearizer(name="member") self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter") @@ -799,9 +806,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if ( not self.allow_per_room_profiles and not is_requester_server_notices_user ) or requester.shadow_banned: - # Strip profile data, knowing that new profile data will be added to the - # event's content in event_creation_handler.create_event() using the target's - # global profile. + # Strip profile data, knowing that new profile data will be added to + # the event's content below using the target's global profile. content.pop("displayname", None) content.pop("avatar_url", None) @@ -837,6 +843,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if action in ["kick", "unban"]: effective_membership_state = "leave" + if effective_membership_state not in Membership.LIST: + raise SynapseError(400, "Invalid membership key") + + # Add profile data for joins etc, if no per-room profile. + if ( + effective_membership_state + in self._membership_types_to_include_profile_data_in + ): + # If event doesn't include a display name, add one. + profile = self.profile_handler + + try: + if "displayname" not in content: + displayname = await profile.get_displayname(target) + if displayname is not None: + content["displayname"] = displayname + if "avatar_url" not in content: + avatar_url = await profile.get_avatar_url(target) + if avatar_url is not None: + content["avatar_url"] = avatar_url + except Exception as e: + logger.info("Failed to get profile information for %r: %s", target, e) + # if this is a join with a 3pid signature, we may need to turn a 3pid # invite into a normal invite before we can handle the join. if third_party_signed is not None: diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 84c4a5d324..4b1c8e81c4 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -58,6 +58,7 @@ from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.server import HomeServer @@ -968,6 +969,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox" + CLEANUP_DEVICE_FEDERATION_OUTBOX = "cleanup_device_federation_outbox" def __init__( self, @@ -993,6 +995,11 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): self._remove_dead_devices_from_device_inbox, ) + self.db_pool.updates.register_background_update_handler( + self.CLEANUP_DEVICE_FEDERATION_OUTBOX, + self._cleanup_device_federation_outbox, + ) + async def _background_drop_index_device_inbox( self, progress: JsonDict, batch_size: int ) -> int: @@ -1084,6 +1091,75 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): return batch_size + async def _cleanup_device_federation_outbox( + self, + progress: JsonDict, + batch_size: int, + ) -> int: + def _cleanup_device_federation_outbox_txn( + txn: LoggingTransaction, + ) -> bool: + if "max_stream_id" in progress: + max_stream_id = progress["max_stream_id"] + else: + txn.execute("SELECT max(stream_id) FROM device_federation_outbox") + res = cast(Tuple[Optional[int]], txn.fetchone()) + if res[0] is None: + # this can only happen if the `device_inbox` table is empty, in which + # case we have no work to do. + return True + else: + max_stream_id = res[0] + + start = progress.get("stream_id", 0) + stop = start + batch_size + + sql = """ + SELECT destination FROM device_federation_outbox + WHERE ? < stream_id AND stream_id <= ? + """ + + txn.execute(sql, (start, stop)) + + destinations = {d for d, in txn} + to_remove = set() + for d in destinations: + try: + parse_and_validate_server_name(d) + except ValueError: + to_remove.add(d) + + self.db_pool.simple_delete_many_txn( + txn, + table="device_federation_outbox", + column="destination", + values=to_remove, + keyvalues={}, + ) + + self.db_pool.updates._background_update_progress_txn( + txn, + self.CLEANUP_DEVICE_FEDERATION_OUTBOX, + { + "stream_id": stop, + "max_stream_id": max_stream_id, + }, + ) + + return stop >= max_stream_id + + finished = await self.db_pool.runInteraction( + "_cleanup_device_federation_outbox", + _cleanup_device_federation_outbox_txn, + ) + + if finished: + await self.db_pool.updates._end_background_update( + self.CLEANUP_DEVICE_FEDERATION_OUTBOX, + ) + + return batch_size + class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): pass diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 7f6d1ab45c..934c98cbcd 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -57,10 +57,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - StreamIdGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import ( JsonDict, JsonMapping, @@ -99,19 +96,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): # In the worker store this is an ID tracker which we overwrite in the non-worker # class below that is used on the main process. - self._device_list_id_gen = StreamIdGenerator( - db_conn, - hs.get_replication_notifier(), - "device_lists_stream", - "stream_id", - extra_tables=[ - ("user_signature_stream", "stream_id"), - ("device_lists_outbound_pokes", "stream_id"), - ("device_lists_changes_in_room", "stream_id"), - ("device_lists_remote_pending", "stream_id"), - ("device_lists_changes_converted_stream_position", "stream_id"), + self._device_list_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="device_lists_stream", + instance_name=self._instance_name, + tables=[ + ("device_lists_stream", "instance_name", "stream_id"), + ("user_signature_stream", "instance_name", "stream_id"), + ("device_lists_outbound_pokes", "instance_name", "stream_id"), + ("device_lists_changes_in_room", "instance_name", "stream_id"), + ("device_lists_remote_pending", "instance_name", "stream_id"), ], - is_writer=hs.config.worker.worker_app is None, + sequence_name="device_lists_sequence", + writers=["master"], ) device_list_max = self._device_list_id_gen.get_current_token() @@ -764,6 +763,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): "stream_id": stream_id, "from_user_id": from_user_id, "user_ids": json_encoder.encode(user_ids), + "instance_name": self._instance_name, }, ) @@ -1584,6 +1584,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) + self._instance_name = hs.get_instance_name() + self.db_pool.updates.register_background_index_update( "device_lists_stream_idx", index_name="device_lists_stream_user_id", @@ -1696,6 +1698,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): "device_lists_outbound_pokes", { "stream_id": stream_id, + "instance_name": self._instance_name, "destination": destination, "user_id": user_id, "device_id": device_id, @@ -1732,10 +1735,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - # Because we have write access, this will be a StreamIdGenerator - # (see DeviceWorkerStore.__init__) - _device_list_id_gen: AbstractStreamIdGenerator - def __init__( self, database: DatabasePool, @@ -2094,9 +2093,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.db_pool.simple_insert_many_txn( txn, table="device_lists_stream", - keys=("stream_id", "user_id", "device_id"), + keys=("instance_name", "stream_id", "user_id", "device_id"), values=[ - (stream_id, user_id, device_id) + (self._instance_name, stream_id, user_id, device_id) for stream_id, device_id in zip(stream_ids, device_ids) ], ) @@ -2126,6 +2125,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values = [ ( destination, + self._instance_name, next(stream_id_iterator), user_id, device_id, @@ -2141,6 +2141,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): table="device_lists_outbound_pokes", keys=( "destination", + "instance_name", "stream_id", "user_id", "device_id", @@ -2159,7 +2160,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device_id, { stream_id: destination - for (destination, stream_id, _, _, _, _, _) in values + for (destination, _, stream_id, _, _, _, _, _) in values }, ) @@ -2212,6 +2213,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): "device_id", "room_id", "stream_id", + "instance_name", "converted_to_destinations", "opentracing_context", ), @@ -2221,6 +2223,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device_id, room_id, stream_id, + self._instance_name, # We only need to calculate outbound pokes for local users not self.hs.is_mine_id(user_id), encoded_context, @@ -2340,7 +2343,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): "user_id": user_id, "device_id": device_id, }, - values={"stream_id": stream_id}, + values={ + "stream_id": stream_id, + "instance_name": self._instance_name, + }, desc="add_remote_device_list_to_pending", ) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index b219ea70ee..38d8785faa 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -58,7 +58,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import JsonDict, JsonMapping from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -1448,11 +1448,17 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ): super().__init__(database, db_conn, hs) - self._cross_signing_id_gen = StreamIdGenerator( - db_conn, - hs.get_replication_notifier(), - "e2e_cross_signing_keys", - "stream_id", + self._cross_signing_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="e2e_cross_signing_keys", + instance_name=self._instance_name, + tables=[ + ("e2e_cross_signing_keys", "instance_name", "stream_id"), + ], + sequence_name="e2e_cross_signing_keys_sequence", + writers=["master"], ) async def set_e2e_device_keys( @@ -1627,6 +1633,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "keytype": key_type, "keydata": json_encoder.encode(key), "stream_id": stream_id, + "instance_name": self._instance_name, }, ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index fd7167904d..f1bd85aa27 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -95,6 +95,10 @@ class DeltaState: to_insert: StateMap[str] no_longer_in_room: bool = False + def is_noop(self) -> bool: + """Whether this state delta is actually empty""" + return not self.to_delete and not self.to_insert and not self.no_longer_in_room + class PersistEventsStore: """Contains all the functions for writing events to the database. @@ -1017,6 +1021,9 @@ class PersistEventsStore: ) -> None: """Update the current state stored in the datatabase for the given room""" + if state_delta.is_noop(): + return + async with self._stream_id_gen.get_next() as stream_ordering: await self.db_pool.runInteraction( "update_current_state", diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a1caf0d3e7..c389ff5713 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -200,7 +200,11 @@ class EventsWorkerStore(SQLBaseStore): notifier=hs.get_replication_notifier(), stream_name="events", instance_name=hs.get_instance_name(), - tables=[("events", "instance_name", "stream_ordering")], + tables=[ + ("events", "instance_name", "stream_ordering"), + ("current_state_delta_stream", "instance_name", "stream_id"), + ("ex_outlier_stream", "instance_name", "event_stream_ordering"), + ], sequence_name="events_stream_seq", writers=hs.config.worker.writers.events, ) @@ -210,7 +214,10 @@ class EventsWorkerStore(SQLBaseStore): notifier=hs.get_replication_notifier(), stream_name="backfill", instance_name=hs.get_instance_name(), - tables=[("events", "instance_name", "stream_ordering")], + tables=[ + ("events", "instance_name", "stream_ordering"), + ("ex_outlier_stream", "instance_name", "event_stream_ordering"), + ], sequence_name="events_backfill_stream_seq", positive=False, writers=hs.config.worker.writers.events, diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 660c834518..2a39dc9f90 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -53,7 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException -from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules from synapse.types import JsonDict from synapse.util import json_encoder, unwrapFirstError @@ -126,7 +126,7 @@ class PushRulesWorkerStore( `get_max_push_rules_stream_id` which can be called in the initializer. """ - _push_rules_stream_id_gen: StreamIdGenerator + _push_rules_stream_id_gen: MultiWriterIdGenerator def __init__( self, @@ -140,14 +140,17 @@ class PushRulesWorkerStore( hs.get_instance_name() in hs.config.worker.writers.push_rules ) - # In the worker store this is an ID tracker which we overwrite in the non-worker - # class below that is used on the main process. - self._push_rules_stream_id_gen = StreamIdGenerator( - db_conn, - hs.get_replication_notifier(), - "push_rules_stream", - "stream_id", - is_writer=self._is_push_writer, + self._push_rules_stream_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="push_rules_stream", + instance_name=self._instance_name, + tables=[ + ("push_rules_stream", "instance_name", "stream_id"), + ], + sequence_name="push_rules_stream_sequence", + writers=hs.config.worker.writers.push_rules, ) push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( @@ -880,6 +883,7 @@ class PushRulesWorkerStore( raise Exception("Not a push writer") values = { + "instance_name": self._instance_name, "stream_id": stream_id, "event_stream_ordering": event_stream_ordering, "user_id": user_id, diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 39e22d3b43..a8a37b6c85 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -40,10 +40,7 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - StreamIdGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -84,15 +81,20 @@ class PusherWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - # In the worker store this is an ID tracker which we overwrite in the non-worker - # class below that is used on the main process. - self._pushers_id_gen = StreamIdGenerator( - db_conn, - hs.get_replication_notifier(), - "pushers", - "id", - extra_tables=[("deleted_pushers", "stream_id")], - is_writer=hs.config.worker.worker_app is None, + self._instance_name = hs.get_instance_name() + + self._pushers_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="pushers", + instance_name=self._instance_name, + tables=[ + ("pushers", "instance_name", "id"), + ("deleted_pushers", "instance_name", "stream_id"), + ], + sequence_name="pushers_sequence", + writers=["master"], ) self.db_pool.updates.register_background_update_handler( @@ -655,7 +657,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore): class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): # Because we have write access, this will be a StreamIdGenerator # (see PusherWorkerStore.__init__) - _pushers_id_gen: AbstractStreamIdGenerator + _pushers_id_gen: MultiWriterIdGenerator async def add_pusher( self, @@ -688,6 +690,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): "last_stream_ordering": last_stream_ordering, "profile_tag": profile_tag, "id": stream_id, + "instance_name": self._instance_name, "enabled": enabled, "device_id": device_id, # XXX(quenting): We're only really persisting the access token ID @@ -735,6 +738,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): table="deleted_pushers", values={ "stream_id": stream_id, + "instance_name": self._instance_name, "app_id": app_id, "pushkey": pushkey, "user_id": user_id, @@ -773,9 +777,15 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): self.db_pool.simple_insert_many_txn( txn, table="deleted_pushers", - keys=("stream_id", "app_id", "pushkey", "user_id"), + keys=("stream_id", "instance_name", "app_id", "pushkey", "user_id"), values=[ - (stream_id, pusher.app_id, pusher.pushkey, user_id) + ( + stream_id, + self._instance_name, + pusher.app_id, + pusher.pushkey, + user_id, + ) for stream_id, pusher in zip(stream_ids, pushers) ], ) diff --git a/synapse/storage/schema/main/delta/85/02_add_instance_names.sql b/synapse/storage/schema/main/delta/85/02_add_instance_names.sql new file mode 100644 index 0000000000..d604595f73 --- /dev/null +++ b/synapse/storage/schema/main/delta/85/02_add_instance_names.sql @@ -0,0 +1,27 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +-- Add `instance_name` columns to stream tables to allow them to be used with +-- `MultiWriterIdGenerator` +ALTER TABLE device_lists_stream ADD COLUMN instance_name TEXT; +ALTER TABLE user_signature_stream ADD COLUMN instance_name TEXT; +ALTER TABLE device_lists_outbound_pokes ADD COLUMN instance_name TEXT; +ALTER TABLE device_lists_changes_in_room ADD COLUMN instance_name TEXT; +ALTER TABLE device_lists_remote_pending ADD COLUMN instance_name TEXT; + +ALTER TABLE e2e_cross_signing_keys ADD COLUMN instance_name TEXT; + +ALTER TABLE push_rules_stream ADD COLUMN instance_name TEXT; + +ALTER TABLE pushers ADD COLUMN instance_name TEXT; +ALTER TABLE deleted_pushers ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/schema/main/delta/85/03_new_sequences.sql.postgres b/synapse/storage/schema/main/delta/85/03_new_sequences.sql.postgres new file mode 100644 index 0000000000..9d34066bf5 --- /dev/null +++ b/synapse/storage/schema/main/delta/85/03_new_sequences.sql.postgres @@ -0,0 +1,54 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +-- Add squences for stream tables to allow them to be used with +-- `MultiWriterIdGenerator` +CREATE SEQUENCE IF NOT EXISTS device_lists_sequence; + +-- We need to take the max across all the device lists tables as they share the +-- ID generator +SELECT setval('device_lists_sequence', ( + SELECT GREATEST( + (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_stream), + (SELECT COALESCE(MAX(stream_id), 1) FROM user_signature_stream), + (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_outbound_pokes), + (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_in_room), + (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_remote_pending), + (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_converted_stream_position) + ) +)); + +CREATE SEQUENCE IF NOT EXISTS e2e_cross_signing_keys_sequence; + +SELECT setval('e2e_cross_signing_keys_sequence', ( + SELECT COALESCE(MAX(stream_id), 1) FROM e2e_cross_signing_keys +)); + + +CREATE SEQUENCE IF NOT EXISTS push_rules_stream_sequence; + +SELECT setval('push_rules_stream_sequence', ( + SELECT COALESCE(MAX(stream_id), 1) FROM push_rules_stream +)); + + +CREATE SEQUENCE IF NOT EXISTS pushers_sequence; + +-- We need to take the max across all the pusher tables as they share the +-- ID generator +SELECT setval('pushers_sequence', ( + SELECT GREATEST( + (SELECT COALESCE(MAX(id), 1) FROM pushers), + (SELECT COALESCE(MAX(stream_id), 1) FROM deleted_pushers) + ) +)); diff --git a/synapse/storage/schema/main/delta/85/04_cleanup_device_federation_outbox.sql b/synapse/storage/schema/main/delta/85/04_cleanup_device_federation_outbox.sql new file mode 100644 index 0000000000..041b17b0ee --- /dev/null +++ b/synapse/storage/schema/main/delta/85/04_cleanup_device_federation_outbox.sql @@ -0,0 +1,15 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (8504, 'cleanup_device_federation_outbox', '{}'); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 0cf5851ad7..59c8e05c39 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -23,15 +23,12 @@ import abc import heapq import logging import threading -from collections import OrderedDict -from contextlib import contextmanager from types import TracebackType from typing import ( TYPE_CHECKING, AsyncContextManager, ContextManager, Dict, - Generator, Generic, Iterable, List, @@ -179,161 +176,6 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta): raise NotImplementedError() -class StreamIdGenerator(AbstractStreamIdGenerator): - """Generates and tracks stream IDs for a stream with a single writer. - - This class must only be used when the current Synapse process is the sole - writer for a stream. - - Args: - db_conn(connection): A database connection to use to fetch the - initial value of the generator from. - table(str): A database table to read the initial value of the id - generator from. - column(str): The column of the database table to read the initial - value from the id generator from. - extra_tables(list): List of pairs of database tables and columns to - use to source the initial value of the generator from. The value - with the largest magnitude is used. - step(int): which direction the stream ids grow in. +1 to grow - upwards, -1 to grow downwards. - - Usage: - async with stream_id_gen.get_next() as stream_id: - # ... persist event ... - """ - - def __init__( - self, - db_conn: LoggingDatabaseConnection, - notifier: "ReplicationNotifier", - table: str, - column: str, - extra_tables: Iterable[Tuple[str, str]] = (), - step: int = 1, - is_writer: bool = True, - ) -> None: - assert step != 0 - self._lock = threading.Lock() - self._step: int = step - self._current: int = _load_current_id(db_conn, table, column, step) - self._is_writer = is_writer - for table, column in extra_tables: - self._current = (max if step > 0 else min)( - self._current, _load_current_id(db_conn, table, column, step) - ) - - # We use this as an ordered set, as we want to efficiently append items, - # remove items and get the first item. Since we insert IDs in order, the - # insertion ordering will ensure its in the correct ordering. - # - # The key and values are the same, but we never look at the values. - self._unfinished_ids: OrderedDict[int, int] = OrderedDict() - - self._notifier = notifier - - def advance(self, instance_name: str, new_id: int) -> None: - # Advance should never be called on a writer instance, only over replication - if self._is_writer: - raise Exception("Replication is not supported by writer StreamIdGenerator") - - self._current = (max if self._step > 0 else min)(self._current, new_id) - - def get_next(self) -> AsyncContextManager[int]: - with self._lock: - self._current += self._step - next_id = self._current - - self._unfinished_ids[next_id] = next_id - - @contextmanager - def manager() -> Generator[int, None, None]: - try: - yield next_id - finally: - with self._lock: - self._unfinished_ids.pop(next_id) - - self._notifier.notify_replication() - - return _AsyncCtxManagerWrapper(manager()) - - def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: - with self._lock: - next_ids = range( - self._current + self._step, - self._current + self._step * (n + 1), - self._step, - ) - self._current += n * self._step - - for next_id in next_ids: - self._unfinished_ids[next_id] = next_id - - @contextmanager - def manager() -> Generator[Sequence[int], None, None]: - try: - yield next_ids - finally: - with self._lock: - for next_id in next_ids: - self._unfinished_ids.pop(next_id) - - self._notifier.notify_replication() - - return _AsyncCtxManagerWrapper(manager()) - - def get_next_txn(self, txn: LoggingTransaction) -> int: - """ - Retrieve the next stream ID from within a database transaction. - - Clean-up functions will be called when the transaction finishes. - - Args: - txn: The database transaction object. - - Returns: - The next stream ID. - """ - if not self._is_writer: - raise Exception("Tried to allocate stream ID on non-writer") - - # Get the next stream ID. - with self._lock: - self._current += self._step - next_id = self._current - - self._unfinished_ids[next_id] = next_id - - def clear_unfinished_id(id_to_clear: int) -> None: - """A function to mark processing this ID as finished""" - with self._lock: - self._unfinished_ids.pop(id_to_clear) - - # Mark this ID as finished once the database transaction itself finishes. - txn.call_after(clear_unfinished_id, next_id) - txn.call_on_exception(clear_unfinished_id, next_id) - - # Return the new ID. - return next_id - - def get_current_token(self) -> int: - if not self._is_writer: - return self._current - - with self._lock: - if self._unfinished_ids: - return next(iter(self._unfinished_ids)) - self._step - - return self._current - - def get_current_token_for_writer(self, instance_name: str) -> int: - return self.get_current_token() - - def get_minimal_local_current_token(self) -> int: - return self.get_current_token() - - class MultiWriterIdGenerator(AbstractStreamIdGenerator): """Generates and tracks stream IDs for a stream with multiple writers. diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py index df43ce581c..213a66ed1a 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py @@ -407,3 +407,24 @@ class RoomMemberMasterHandlerTestCase(HomeserverTestCase): self.assertFalse( self.get_success(self.store.did_forget(self.alice, self.room_id)) ) + + def test_deduplicate_joins(self) -> None: + """ + Test that calling /join multiple times does not store a new state group. + """ + + self.helper.join(self.room_id, user=self.bob, tok=self.bob_token) + + sql = "SELECT COUNT(*) FROM state_groups WHERE room_id = ?" + rows = self.get_success( + self.store.db_pool.execute("test_deduplicate_joins", sql, self.room_id) + ) + initial_count = rows[0][0] + + self.helper.join(self.room_id, user=self.bob, tok=self.bob_token) + rows = self.get_success( + self.store.db_pool.execute("test_deduplicate_joins", sql, self.room_id) + ) + new_count = rows[0][0] + + self.assertEqual(initial_count, new_count) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index fad9511cea..f0307252f3 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -30,7 +30,7 @@ from synapse.storage.database import ( ) from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.sequence import ( LocalSequenceGenerator, PostgresSequenceGenerator, @@ -42,144 +42,6 @@ from tests.unittest import HomeserverTestCase from tests.utils import USE_POSTGRES_FOR_TESTS -class StreamIdGeneratorTestCase(HomeserverTestCase): - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.db_pool: DatabasePool = self.store.db_pool - - self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) - - def _setup_db(self, txn: LoggingTransaction) -> None: - txn.execute( - """ - CREATE TABLE foobar ( - stream_id BIGINT NOT NULL, - data TEXT - ); - """ - ) - txn.execute("INSERT INTO foobar VALUES (123, 'hello world');") - - def _create_id_generator(self) -> StreamIdGenerator: - def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator: - return StreamIdGenerator( - db_conn=conn, - notifier=self.hs.get_replication_notifier(), - table="foobar", - column="stream_id", - ) - - return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) - - def test_initial_value(self) -> None: - """Check that we read the current token from the DB.""" - id_gen = self._create_id_generator() - self.assertEqual(id_gen.get_current_token(), 123) - - def test_single_gen_next(self) -> None: - """Check that we correctly increment the current token from the DB.""" - id_gen = self._create_id_generator() - - async def test_gen_next() -> None: - async with id_gen.get_next() as next_id: - # We haven't persisted `next_id` yet; current token is still 123 - self.assertEqual(id_gen.get_current_token(), 123) - # But we did learn what the next value is - self.assertEqual(next_id, 124) - - # Once the context manager closes we assume that the `next_id` has been - # written to the DB. - self.assertEqual(id_gen.get_current_token(), 124) - - self.get_success(test_gen_next()) - - def test_multiple_gen_nexts(self) -> None: - """Check that we handle overlapping calls to gen_next sensibly.""" - id_gen = self._create_id_generator() - - async def test_gen_next() -> None: - ctx1 = id_gen.get_next() - ctx2 = id_gen.get_next() - ctx3 = id_gen.get_next() - - # Request three new stream IDs. - self.assertEqual(await ctx1.__aenter__(), 124) - self.assertEqual(await ctx2.__aenter__(), 125) - self.assertEqual(await ctx3.__aenter__(), 126) - - # None are persisted: current token unchanged. - self.assertEqual(id_gen.get_current_token(), 123) - - # Persist each in turn. - await ctx1.__aexit__(None, None, None) - self.assertEqual(id_gen.get_current_token(), 124) - await ctx2.__aexit__(None, None, None) - self.assertEqual(id_gen.get_current_token(), 125) - await ctx3.__aexit__(None, None, None) - self.assertEqual(id_gen.get_current_token(), 126) - - self.get_success(test_gen_next()) - - def test_multiple_gen_nexts_closed_in_different_order(self) -> None: - """Check that we handle overlapping calls to gen_next, even when their IDs - created and persisted in different orders.""" - id_gen = self._create_id_generator() - - async def test_gen_next() -> None: - ctx1 = id_gen.get_next() - ctx2 = id_gen.get_next() - ctx3 = id_gen.get_next() - - # Request three new stream IDs. - self.assertEqual(await ctx1.__aenter__(), 124) - self.assertEqual(await ctx2.__aenter__(), 125) - self.assertEqual(await ctx3.__aenter__(), 126) - - # None are persisted: current token unchanged. - self.assertEqual(id_gen.get_current_token(), 123) - - # Persist them in a different order, starting with 126 from ctx3. - await ctx3.__aexit__(None, None, None) - # We haven't persisted 124 from ctx1 yet---current token is still 123. - self.assertEqual(id_gen.get_current_token(), 123) - - # Now persist 124 from ctx1. - await ctx1.__aexit__(None, None, None) - # Current token is then 124, waiting for 125 to be persisted. - self.assertEqual(id_gen.get_current_token(), 124) - - # Finally persist 125 from ctx2. - await ctx2.__aexit__(None, None, None) - # Current token is then 126 (skipping over 125). - self.assertEqual(id_gen.get_current_token(), 126) - - self.get_success(test_gen_next()) - - def test_gen_next_while_still_waiting_for_persistence(self) -> None: - """Check that we handle overlapping calls to gen_next.""" - id_gen = self._create_id_generator() - - async def test_gen_next() -> None: - ctx1 = id_gen.get_next() - ctx2 = id_gen.get_next() - ctx3 = id_gen.get_next() - - # Request two new stream IDs. - self.assertEqual(await ctx1.__aenter__(), 124) - self.assertEqual(await ctx2.__aenter__(), 125) - - # Persist ctx2 first. - await ctx2.__aexit__(None, None, None) - # Still waiting on ctx1's ID to be persisted. - self.assertEqual(id_gen.get_current_token(), 123) - - # Now request a third stream ID. It should be 126 (the smallest ID that - # we've not yet handed out.) - self.assertEqual(await ctx3.__aenter__(), 126) - - self.get_success(test_gen_next()) - - class MultiWriterIdGeneratorBase(HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main