mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-20 02:16:01 +03:00
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
This commit is contained in:
commit
293eeffb0c
19 changed files with 389 additions and 400 deletions
1
changelog.d/17164.bugfix
Normal file
1
changelog.d/17164.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix deduplicating of membership events to not create unused state groups.
|
1
changelog.d/17229.misc
Normal file
1
changelog.d/17229.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator`.
|
1
changelog.d/17242.misc
Normal file
1
changelog.d/17242.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Clean out invalid destinations from `device_federation_outbox` table.
|
|
@ -777,22 +777,74 @@ class Porter:
|
|||
await self._setup_events_stream_seqs()
|
||||
await self._setup_sequence(
|
||||
"un_partial_stated_event_stream_sequence",
|
||||
("un_partial_stated_event_stream",),
|
||||
[("un_partial_stated_event_stream", "stream_id")],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"device_inbox_sequence", ("device_inbox", "device_federation_outbox")
|
||||
"device_inbox_sequence",
|
||||
[
|
||||
("device_inbox", "stream_id"),
|
||||
("device_federation_outbox", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"account_data_sequence",
|
||||
("room_account_data", "room_tags_revisions", "account_data"),
|
||||
[
|
||||
("room_account_data", "stream_id"),
|
||||
("room_tags_revisions", "stream_id"),
|
||||
("account_data", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"receipts_sequence",
|
||||
[
|
||||
("receipts_linearized", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"presence_stream_sequence",
|
||||
[
|
||||
("presence_stream", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence("receipts_sequence", ("receipts_linearized",))
|
||||
await self._setup_sequence("presence_stream_sequence", ("presence_stream",))
|
||||
await self._setup_auth_chain_sequence()
|
||||
await self._setup_sequence(
|
||||
"application_services_txn_id_seq",
|
||||
("application_services_txns",),
|
||||
"txn_id",
|
||||
[
|
||||
(
|
||||
"application_services_txns",
|
||||
"txn_id",
|
||||
)
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"device_lists_sequence",
|
||||
[
|
||||
("device_lists_stream", "stream_id"),
|
||||
("user_signature_stream", "stream_id"),
|
||||
("device_lists_outbound_pokes", "stream_id"),
|
||||
("device_lists_changes_in_room", "stream_id"),
|
||||
("device_lists_remote_pending", "stream_id"),
|
||||
("device_lists_changes_converted_stream_position", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"e2e_cross_signing_keys_sequence",
|
||||
[
|
||||
("e2e_cross_signing_keys", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"push_rules_stream_sequence",
|
||||
[
|
||||
("push_rules_stream", "stream_id"),
|
||||
],
|
||||
)
|
||||
await self._setup_sequence(
|
||||
"pushers_sequence",
|
||||
[
|
||||
("pushers", "id"),
|
||||
("deleted_pushers", "stream_id"),
|
||||
],
|
||||
)
|
||||
|
||||
# Step 3. Get tables.
|
||||
|
@ -1101,12 +1153,11 @@ class Porter:
|
|||
async def _setup_sequence(
|
||||
self,
|
||||
sequence_name: str,
|
||||
stream_id_tables: Iterable[str],
|
||||
column_name: str = "stream_id",
|
||||
stream_id_tables: Iterable[Tuple[str, str]],
|
||||
) -> None:
|
||||
"""Set a sequence to the correct value."""
|
||||
current_stream_ids = []
|
||||
for stream_id_table in stream_id_tables:
|
||||
for stream_id_table, column_name in stream_id_tables:
|
||||
max_stream_id = cast(
|
||||
int,
|
||||
await self.sqlite_store.db_pool.simple_select_one_onecol(
|
||||
|
|
|
@ -496,13 +496,6 @@ class EventCreationHandler:
|
|||
|
||||
self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
|
||||
|
||||
self.membership_types_to_include_profile_data_in = {
|
||||
Membership.JOIN,
|
||||
Membership.KNOCK,
|
||||
}
|
||||
if self.hs.config.server.include_profile_data_on_invite:
|
||||
self.membership_types_to_include_profile_data_in.add(Membership.INVITE)
|
||||
|
||||
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
|
||||
self.send_events = ReplicationSendEventsRestServlet.make_client(hs)
|
||||
|
||||
|
@ -594,8 +587,6 @@ class EventCreationHandler:
|
|||
Creates an FrozenEvent object, filling out auth_events, prev_events,
|
||||
etc.
|
||||
|
||||
Adds display names to Join membership events.
|
||||
|
||||
Args:
|
||||
requester
|
||||
event_dict: An entire event
|
||||
|
@ -672,29 +663,6 @@ class EventCreationHandler:
|
|||
|
||||
self.validator.validate_builder(builder)
|
||||
|
||||
if builder.type == EventTypes.Member:
|
||||
membership = builder.content.get("membership", None)
|
||||
target = UserID.from_string(builder.state_key)
|
||||
|
||||
if membership in self.membership_types_to_include_profile_data_in:
|
||||
# If event doesn't include a display name, add one.
|
||||
profile = self.profile_handler
|
||||
content = builder.content
|
||||
|
||||
try:
|
||||
if "displayname" not in content:
|
||||
displayname = await profile.get_displayname(target)
|
||||
if displayname is not None:
|
||||
content["displayname"] = displayname
|
||||
if "avatar_url" not in content:
|
||||
avatar_url = await profile.get_avatar_url(target)
|
||||
if avatar_url is not None:
|
||||
content["avatar_url"] = avatar_url
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"Failed to get profile information for %r: %s", target, e
|
||||
)
|
||||
|
||||
is_exempt = await self._is_exempt_from_privacy_policy(builder, requester)
|
||||
if require_consent and not is_exempt:
|
||||
await self.assert_accepted_privacy_policy(requester)
|
||||
|
|
|
@ -106,6 +106,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
self.event_auth_handler = hs.get_event_auth_handler()
|
||||
self._worker_lock_handler = hs.get_worker_locks_handler()
|
||||
|
||||
self._membership_types_to_include_profile_data_in = {
|
||||
Membership.JOIN,
|
||||
Membership.KNOCK,
|
||||
}
|
||||
if self.hs.config.server.include_profile_data_on_invite:
|
||||
self._membership_types_to_include_profile_data_in.add(Membership.INVITE)
|
||||
|
||||
self.member_linearizer: Linearizer = Linearizer(name="member")
|
||||
self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
|
||||
|
||||
|
@ -799,9 +806,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
if (
|
||||
not self.allow_per_room_profiles and not is_requester_server_notices_user
|
||||
) or requester.shadow_banned:
|
||||
# Strip profile data, knowing that new profile data will be added to the
|
||||
# event's content in event_creation_handler.create_event() using the target's
|
||||
# global profile.
|
||||
# Strip profile data, knowing that new profile data will be added to
|
||||
# the event's content below using the target's global profile.
|
||||
content.pop("displayname", None)
|
||||
content.pop("avatar_url", None)
|
||||
|
||||
|
@ -837,6 +843,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
if action in ["kick", "unban"]:
|
||||
effective_membership_state = "leave"
|
||||
|
||||
if effective_membership_state not in Membership.LIST:
|
||||
raise SynapseError(400, "Invalid membership key")
|
||||
|
||||
# Add profile data for joins etc, if no per-room profile.
|
||||
if (
|
||||
effective_membership_state
|
||||
in self._membership_types_to_include_profile_data_in
|
||||
):
|
||||
# If event doesn't include a display name, add one.
|
||||
profile = self.profile_handler
|
||||
|
||||
try:
|
||||
if "displayname" not in content:
|
||||
displayname = await profile.get_displayname(target)
|
||||
if displayname is not None:
|
||||
content["displayname"] = displayname
|
||||
if "avatar_url" not in content:
|
||||
avatar_url = await profile.get_avatar_url(target)
|
||||
if avatar_url is not None:
|
||||
content["avatar_url"] = avatar_url
|
||||
except Exception as e:
|
||||
logger.info("Failed to get profile information for %r: %s", target, e)
|
||||
|
||||
# if this is a join with a 3pid signature, we may need to turn a 3pid
|
||||
# invite into a normal invite before we can handle the join.
|
||||
if third_party_signed is not None:
|
||||
|
|
|
@ -58,6 +58,7 @@ from synapse.types import JsonDict
|
|||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
from synapse.util.stringutils import parse_and_validate_server_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -968,6 +969,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
||||
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
|
||||
REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
|
||||
CLEANUP_DEVICE_FEDERATION_OUTBOX = "cleanup_device_federation_outbox"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -993,6 +995,11 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
|||
self._remove_dead_devices_from_device_inbox,
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
self.CLEANUP_DEVICE_FEDERATION_OUTBOX,
|
||||
self._cleanup_device_federation_outbox,
|
||||
)
|
||||
|
||||
async def _background_drop_index_device_inbox(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
|
@ -1084,6 +1091,75 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
|||
|
||||
return batch_size
|
||||
|
||||
async def _cleanup_device_federation_outbox(
|
||||
self,
|
||||
progress: JsonDict,
|
||||
batch_size: int,
|
||||
) -> int:
|
||||
def _cleanup_device_federation_outbox_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> bool:
|
||||
if "max_stream_id" in progress:
|
||||
max_stream_id = progress["max_stream_id"]
|
||||
else:
|
||||
txn.execute("SELECT max(stream_id) FROM device_federation_outbox")
|
||||
res = cast(Tuple[Optional[int]], txn.fetchone())
|
||||
if res[0] is None:
|
||||
# this can only happen if the `device_inbox` table is empty, in which
|
||||
# case we have no work to do.
|
||||
return True
|
||||
else:
|
||||
max_stream_id = res[0]
|
||||
|
||||
start = progress.get("stream_id", 0)
|
||||
stop = start + batch_size
|
||||
|
||||
sql = """
|
||||
SELECT destination FROM device_federation_outbox
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (start, stop))
|
||||
|
||||
destinations = {d for d, in txn}
|
||||
to_remove = set()
|
||||
for d in destinations:
|
||||
try:
|
||||
parse_and_validate_server_name(d)
|
||||
except ValueError:
|
||||
to_remove.add(d)
|
||||
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="device_federation_outbox",
|
||||
column="destination",
|
||||
values=to_remove,
|
||||
keyvalues={},
|
||||
)
|
||||
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn,
|
||||
self.CLEANUP_DEVICE_FEDERATION_OUTBOX,
|
||||
{
|
||||
"stream_id": stop,
|
||||
"max_stream_id": max_stream_id,
|
||||
},
|
||||
)
|
||||
|
||||
return stop >= max_stream_id
|
||||
|
||||
finished = await self.db_pool.runInteraction(
|
||||
"_cleanup_device_federation_outbox",
|
||||
_cleanup_device_federation_outbox_txn,
|
||||
)
|
||||
|
||||
if finished:
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.CLEANUP_DEVICE_FEDERATION_OUTBOX,
|
||||
)
|
||||
|
||||
return batch_size
|
||||
|
||||
|
||||
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
|
||||
pass
|
||||
|
|
|
@ -57,10 +57,7 @@ from synapse.storage.database import (
|
|||
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
|
||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
JsonMapping,
|
||||
|
@ -99,19 +96,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
|
||||
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
||||
# class below that is used on the main process.
|
||||
self._device_list_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"device_lists_stream",
|
||||
"stream_id",
|
||||
extra_tables=[
|
||||
("user_signature_stream", "stream_id"),
|
||||
("device_lists_outbound_pokes", "stream_id"),
|
||||
("device_lists_changes_in_room", "stream_id"),
|
||||
("device_lists_remote_pending", "stream_id"),
|
||||
("device_lists_changes_converted_stream_position", "stream_id"),
|
||||
self._device_list_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="device_lists_stream",
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("device_lists_stream", "instance_name", "stream_id"),
|
||||
("user_signature_stream", "instance_name", "stream_id"),
|
||||
("device_lists_outbound_pokes", "instance_name", "stream_id"),
|
||||
("device_lists_changes_in_room", "instance_name", "stream_id"),
|
||||
("device_lists_remote_pending", "instance_name", "stream_id"),
|
||||
],
|
||||
is_writer=hs.config.worker.worker_app is None,
|
||||
sequence_name="device_lists_sequence",
|
||||
writers=["master"],
|
||||
)
|
||||
|
||||
device_list_max = self._device_list_id_gen.get_current_token()
|
||||
|
@ -764,6 +763,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
"stream_id": stream_id,
|
||||
"from_user_id": from_user_id,
|
||||
"user_ids": json_encoder.encode(user_ids),
|
||||
"instance_name": self._instance_name,
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -1584,6 +1584,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
|||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
"device_lists_stream_idx",
|
||||
index_name="device_lists_stream_user_id",
|
||||
|
@ -1696,6 +1698,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
|||
"device_lists_outbound_pokes",
|
||||
{
|
||||
"stream_id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
"destination": destination,
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
|
@ -1732,10 +1735,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
|||
|
||||
|
||||
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
# Because we have write access, this will be a StreamIdGenerator
|
||||
# (see DeviceWorkerStore.__init__)
|
||||
_device_list_id_gen: AbstractStreamIdGenerator
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
|
@ -2094,9 +2093,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="device_lists_stream",
|
||||
keys=("stream_id", "user_id", "device_id"),
|
||||
keys=("instance_name", "stream_id", "user_id", "device_id"),
|
||||
values=[
|
||||
(stream_id, user_id, device_id)
|
||||
(self._instance_name, stream_id, user_id, device_id)
|
||||
for stream_id, device_id in zip(stream_ids, device_ids)
|
||||
],
|
||||
)
|
||||
|
@ -2126,6 +2125,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
values = [
|
||||
(
|
||||
destination,
|
||||
self._instance_name,
|
||||
next(stream_id_iterator),
|
||||
user_id,
|
||||
device_id,
|
||||
|
@ -2141,6 +2141,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
table="device_lists_outbound_pokes",
|
||||
keys=(
|
||||
"destination",
|
||||
"instance_name",
|
||||
"stream_id",
|
||||
"user_id",
|
||||
"device_id",
|
||||
|
@ -2159,7 +2160,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
device_id,
|
||||
{
|
||||
stream_id: destination
|
||||
for (destination, stream_id, _, _, _, _, _) in values
|
||||
for (destination, _, stream_id, _, _, _, _, _) in values
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -2212,6 +2213,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
"device_id",
|
||||
"room_id",
|
||||
"stream_id",
|
||||
"instance_name",
|
||||
"converted_to_destinations",
|
||||
"opentracing_context",
|
||||
),
|
||||
|
@ -2221,6 +2223,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
device_id,
|
||||
room_id,
|
||||
stream_id,
|
||||
self._instance_name,
|
||||
# We only need to calculate outbound pokes for local users
|
||||
not self.hs.is_mine_id(user_id),
|
||||
encoded_context,
|
||||
|
@ -2340,7 +2343,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
values={"stream_id": stream_id},
|
||||
values={
|
||||
"stream_id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
},
|
||||
desc="add_remote_device_list_to_pending",
|
||||
)
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@ from synapse.storage.database import (
|
|||
)
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types import JsonDict, JsonMapping
|
||||
from synapse.util import json_decoder, json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
@ -1448,11 +1448,17 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._cross_signing_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"e2e_cross_signing_keys",
|
||||
"stream_id",
|
||||
self._cross_signing_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="e2e_cross_signing_keys",
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("e2e_cross_signing_keys", "instance_name", "stream_id"),
|
||||
],
|
||||
sequence_name="e2e_cross_signing_keys_sequence",
|
||||
writers=["master"],
|
||||
)
|
||||
|
||||
async def set_e2e_device_keys(
|
||||
|
@ -1627,6 +1633,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
"keytype": key_type,
|
||||
"keydata": json_encoder.encode(key),
|
||||
"stream_id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
@ -95,6 +95,10 @@ class DeltaState:
|
|||
to_insert: StateMap[str]
|
||||
no_longer_in_room: bool = False
|
||||
|
||||
def is_noop(self) -> bool:
|
||||
"""Whether this state delta is actually empty"""
|
||||
return not self.to_delete and not self.to_insert and not self.no_longer_in_room
|
||||
|
||||
|
||||
class PersistEventsStore:
|
||||
"""Contains all the functions for writing events to the database.
|
||||
|
@ -1017,6 +1021,9 @@ class PersistEventsStore:
|
|||
) -> None:
|
||||
"""Update the current state stored in the datatabase for the given room"""
|
||||
|
||||
if state_delta.is_noop():
|
||||
return
|
||||
|
||||
async with self._stream_id_gen.get_next() as stream_ordering:
|
||||
await self.db_pool.runInteraction(
|
||||
"update_current_state",
|
||||
|
|
|
@ -200,7 +200,11 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="events",
|
||||
instance_name=hs.get_instance_name(),
|
||||
tables=[("events", "instance_name", "stream_ordering")],
|
||||
tables=[
|
||||
("events", "instance_name", "stream_ordering"),
|
||||
("current_state_delta_stream", "instance_name", "stream_id"),
|
||||
("ex_outlier_stream", "instance_name", "event_stream_ordering"),
|
||||
],
|
||||
sequence_name="events_stream_seq",
|
||||
writers=hs.config.worker.writers.events,
|
||||
)
|
||||
|
@ -210,7 +214,10 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="backfill",
|
||||
instance_name=hs.get_instance_name(),
|
||||
tables=[("events", "instance_name", "stream_ordering")],
|
||||
tables=[
|
||||
("events", "instance_name", "stream_ordering"),
|
||||
("ex_outlier_stream", "instance_name", "event_stream_ordering"),
|
||||
],
|
||||
sequence_name="events_backfill_stream_seq",
|
||||
positive=False,
|
||||
writers=hs.config.worker.writers.events,
|
||||
|
|
|
@ -53,7 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
|
|||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
|
||||
from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator
|
||||
from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
|
||||
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder, unwrapFirstError
|
||||
|
@ -126,7 +126,7 @@ class PushRulesWorkerStore(
|
|||
`get_max_push_rules_stream_id` which can be called in the initializer.
|
||||
"""
|
||||
|
||||
_push_rules_stream_id_gen: StreamIdGenerator
|
||||
_push_rules_stream_id_gen: MultiWriterIdGenerator
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -140,14 +140,17 @@ class PushRulesWorkerStore(
|
|||
hs.get_instance_name() in hs.config.worker.writers.push_rules
|
||||
)
|
||||
|
||||
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
||||
# class below that is used on the main process.
|
||||
self._push_rules_stream_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"push_rules_stream",
|
||||
"stream_id",
|
||||
is_writer=self._is_push_writer,
|
||||
self._push_rules_stream_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="push_rules_stream",
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("push_rules_stream", "instance_name", "stream_id"),
|
||||
],
|
||||
sequence_name="push_rules_stream_sequence",
|
||||
writers=hs.config.worker.writers.push_rules,
|
||||
)
|
||||
|
||||
push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
|
||||
|
@ -880,6 +883,7 @@ class PushRulesWorkerStore(
|
|||
raise Exception("Not a push writer")
|
||||
|
||||
values = {
|
||||
"instance_name": self._instance_name,
|
||||
"stream_id": stream_id,
|
||||
"event_stream_ordering": event_stream_ordering,
|
||||
"user_id": user_id,
|
||||
|
|
|
@ -40,10 +40,7 @@ from synapse.storage.database import (
|
|||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
@ -84,15 +81,20 @@ class PusherWorkerStore(SQLBaseStore):
|
|||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
# In the worker store this is an ID tracker which we overwrite in the non-worker
|
||||
# class below that is used on the main process.
|
||||
self._pushers_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
hs.get_replication_notifier(),
|
||||
"pushers",
|
||||
"id",
|
||||
extra_tables=[("deleted_pushers", "stream_id")],
|
||||
is_writer=hs.config.worker.worker_app is None,
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
self._pushers_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="pushers",
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("pushers", "instance_name", "id"),
|
||||
("deleted_pushers", "instance_name", "stream_id"),
|
||||
],
|
||||
sequence_name="pushers_sequence",
|
||||
writers=["master"],
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
|
@ -655,7 +657,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
|
|||
class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
||||
# Because we have write access, this will be a StreamIdGenerator
|
||||
# (see PusherWorkerStore.__init__)
|
||||
_pushers_id_gen: AbstractStreamIdGenerator
|
||||
_pushers_id_gen: MultiWriterIdGenerator
|
||||
|
||||
async def add_pusher(
|
||||
self,
|
||||
|
@ -688,6 +690,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
|||
"last_stream_ordering": last_stream_ordering,
|
||||
"profile_tag": profile_tag,
|
||||
"id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
"enabled": enabled,
|
||||
"device_id": device_id,
|
||||
# XXX(quenting): We're only really persisting the access token ID
|
||||
|
@ -735,6 +738,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
|||
table="deleted_pushers",
|
||||
values={
|
||||
"stream_id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
"app_id": app_id,
|
||||
"pushkey": pushkey,
|
||||
"user_id": user_id,
|
||||
|
@ -773,9 +777,15 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
|||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="deleted_pushers",
|
||||
keys=("stream_id", "app_id", "pushkey", "user_id"),
|
||||
keys=("stream_id", "instance_name", "app_id", "pushkey", "user_id"),
|
||||
values=[
|
||||
(stream_id, pusher.app_id, pusher.pushkey, user_id)
|
||||
(
|
||||
stream_id,
|
||||
self._instance_name,
|
||||
pusher.app_id,
|
||||
pusher.pushkey,
|
||||
user_id,
|
||||
)
|
||||
for stream_id, pusher in zip(stream_ids, pushers)
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2024 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
-- Add `instance_name` columns to stream tables to allow them to be used with
|
||||
-- `MultiWriterIdGenerator`
|
||||
ALTER TABLE device_lists_stream ADD COLUMN instance_name TEXT;
|
||||
ALTER TABLE user_signature_stream ADD COLUMN instance_name TEXT;
|
||||
ALTER TABLE device_lists_outbound_pokes ADD COLUMN instance_name TEXT;
|
||||
ALTER TABLE device_lists_changes_in_room ADD COLUMN instance_name TEXT;
|
||||
ALTER TABLE device_lists_remote_pending ADD COLUMN instance_name TEXT;
|
||||
|
||||
ALTER TABLE e2e_cross_signing_keys ADD COLUMN instance_name TEXT;
|
||||
|
||||
ALTER TABLE push_rules_stream ADD COLUMN instance_name TEXT;
|
||||
|
||||
ALTER TABLE pushers ADD COLUMN instance_name TEXT;
|
||||
ALTER TABLE deleted_pushers ADD COLUMN instance_name TEXT;
|
|
@ -0,0 +1,54 @@
|
|||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2024 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
-- Add squences for stream tables to allow them to be used with
|
||||
-- `MultiWriterIdGenerator`
|
||||
CREATE SEQUENCE IF NOT EXISTS device_lists_sequence;
|
||||
|
||||
-- We need to take the max across all the device lists tables as they share the
|
||||
-- ID generator
|
||||
SELECT setval('device_lists_sequence', (
|
||||
SELECT GREATEST(
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_stream),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM user_signature_stream),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_outbound_pokes),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_in_room),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_remote_pending),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_converted_stream_position)
|
||||
)
|
||||
));
|
||||
|
||||
CREATE SEQUENCE IF NOT EXISTS e2e_cross_signing_keys_sequence;
|
||||
|
||||
SELECT setval('e2e_cross_signing_keys_sequence', (
|
||||
SELECT COALESCE(MAX(stream_id), 1) FROM e2e_cross_signing_keys
|
||||
));
|
||||
|
||||
|
||||
CREATE SEQUENCE IF NOT EXISTS push_rules_stream_sequence;
|
||||
|
||||
SELECT setval('push_rules_stream_sequence', (
|
||||
SELECT COALESCE(MAX(stream_id), 1) FROM push_rules_stream
|
||||
));
|
||||
|
||||
|
||||
CREATE SEQUENCE IF NOT EXISTS pushers_sequence;
|
||||
|
||||
-- We need to take the max across all the pusher tables as they share the
|
||||
-- ID generator
|
||||
SELECT setval('pushers_sequence', (
|
||||
SELECT GREATEST(
|
||||
(SELECT COALESCE(MAX(id), 1) FROM pushers),
|
||||
(SELECT COALESCE(MAX(stream_id), 1) FROM deleted_pushers)
|
||||
)
|
||||
));
|
|
@ -0,0 +1,15 @@
|
|||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2024 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(8504, 'cleanup_device_federation_outbox', '{}');
|
|
@ -23,15 +23,12 @@ import abc
|
|||
import heapq
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AsyncContextManager,
|
||||
ContextManager,
|
||||
Dict,
|
||||
Generator,
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
|
@ -179,161 +176,6 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
|
|||
raise NotImplementedError()
|
||||
|
||||
|
||||
class StreamIdGenerator(AbstractStreamIdGenerator):
|
||||
"""Generates and tracks stream IDs for a stream with a single writer.
|
||||
|
||||
This class must only be used when the current Synapse process is the sole
|
||||
writer for a stream.
|
||||
|
||||
Args:
|
||||
db_conn(connection): A database connection to use to fetch the
|
||||
initial value of the generator from.
|
||||
table(str): A database table to read the initial value of the id
|
||||
generator from.
|
||||
column(str): The column of the database table to read the initial
|
||||
value from the id generator from.
|
||||
extra_tables(list): List of pairs of database tables and columns to
|
||||
use to source the initial value of the generator from. The value
|
||||
with the largest magnitude is used.
|
||||
step(int): which direction the stream ids grow in. +1 to grow
|
||||
upwards, -1 to grow downwards.
|
||||
|
||||
Usage:
|
||||
async with stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
notifier: "ReplicationNotifier",
|
||||
table: str,
|
||||
column: str,
|
||||
extra_tables: Iterable[Tuple[str, str]] = (),
|
||||
step: int = 1,
|
||||
is_writer: bool = True,
|
||||
) -> None:
|
||||
assert step != 0
|
||||
self._lock = threading.Lock()
|
||||
self._step: int = step
|
||||
self._current: int = _load_current_id(db_conn, table, column, step)
|
||||
self._is_writer = is_writer
|
||||
for table, column in extra_tables:
|
||||
self._current = (max if step > 0 else min)(
|
||||
self._current, _load_current_id(db_conn, table, column, step)
|
||||
)
|
||||
|
||||
# We use this as an ordered set, as we want to efficiently append items,
|
||||
# remove items and get the first item. Since we insert IDs in order, the
|
||||
# insertion ordering will ensure its in the correct ordering.
|
||||
#
|
||||
# The key and values are the same, but we never look at the values.
|
||||
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
|
||||
|
||||
self._notifier = notifier
|
||||
|
||||
def advance(self, instance_name: str, new_id: int) -> None:
|
||||
# Advance should never be called on a writer instance, only over replication
|
||||
if self._is_writer:
|
||||
raise Exception("Replication is not supported by writer StreamIdGenerator")
|
||||
|
||||
self._current = (max if self._step > 0 else min)(self._current, new_id)
|
||||
|
||||
def get_next(self) -> AsyncContextManager[int]:
|
||||
with self._lock:
|
||||
self._current += self._step
|
||||
next_id = self._current
|
||||
|
||||
self._unfinished_ids[next_id] = next_id
|
||||
|
||||
@contextmanager
|
||||
def manager() -> Generator[int, None, None]:
|
||||
try:
|
||||
yield next_id
|
||||
finally:
|
||||
with self._lock:
|
||||
self._unfinished_ids.pop(next_id)
|
||||
|
||||
self._notifier.notify_replication()
|
||||
|
||||
return _AsyncCtxManagerWrapper(manager())
|
||||
|
||||
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
|
||||
with self._lock:
|
||||
next_ids = range(
|
||||
self._current + self._step,
|
||||
self._current + self._step * (n + 1),
|
||||
self._step,
|
||||
)
|
||||
self._current += n * self._step
|
||||
|
||||
for next_id in next_ids:
|
||||
self._unfinished_ids[next_id] = next_id
|
||||
|
||||
@contextmanager
|
||||
def manager() -> Generator[Sequence[int], None, None]:
|
||||
try:
|
||||
yield next_ids
|
||||
finally:
|
||||
with self._lock:
|
||||
for next_id in next_ids:
|
||||
self._unfinished_ids.pop(next_id)
|
||||
|
||||
self._notifier.notify_replication()
|
||||
|
||||
return _AsyncCtxManagerWrapper(manager())
|
||||
|
||||
def get_next_txn(self, txn: LoggingTransaction) -> int:
|
||||
"""
|
||||
Retrieve the next stream ID from within a database transaction.
|
||||
|
||||
Clean-up functions will be called when the transaction finishes.
|
||||
|
||||
Args:
|
||||
txn: The database transaction object.
|
||||
|
||||
Returns:
|
||||
The next stream ID.
|
||||
"""
|
||||
if not self._is_writer:
|
||||
raise Exception("Tried to allocate stream ID on non-writer")
|
||||
|
||||
# Get the next stream ID.
|
||||
with self._lock:
|
||||
self._current += self._step
|
||||
next_id = self._current
|
||||
|
||||
self._unfinished_ids[next_id] = next_id
|
||||
|
||||
def clear_unfinished_id(id_to_clear: int) -> None:
|
||||
"""A function to mark processing this ID as finished"""
|
||||
with self._lock:
|
||||
self._unfinished_ids.pop(id_to_clear)
|
||||
|
||||
# Mark this ID as finished once the database transaction itself finishes.
|
||||
txn.call_after(clear_unfinished_id, next_id)
|
||||
txn.call_on_exception(clear_unfinished_id, next_id)
|
||||
|
||||
# Return the new ID.
|
||||
return next_id
|
||||
|
||||
def get_current_token(self) -> int:
|
||||
if not self._is_writer:
|
||||
return self._current
|
||||
|
||||
with self._lock:
|
||||
if self._unfinished_ids:
|
||||
return next(iter(self._unfinished_ids)) - self._step
|
||||
|
||||
return self._current
|
||||
|
||||
def get_current_token_for_writer(self, instance_name: str) -> int:
|
||||
return self.get_current_token()
|
||||
|
||||
def get_minimal_local_current_token(self) -> int:
|
||||
return self.get_current_token()
|
||||
|
||||
|
||||
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||
"""Generates and tracks stream IDs for a stream with multiple writers.
|
||||
|
||||
|
|
|
@ -407,3 +407,24 @@ class RoomMemberMasterHandlerTestCase(HomeserverTestCase):
|
|||
self.assertFalse(
|
||||
self.get_success(self.store.did_forget(self.alice, self.room_id))
|
||||
)
|
||||
|
||||
def test_deduplicate_joins(self) -> None:
|
||||
"""
|
||||
Test that calling /join multiple times does not store a new state group.
|
||||
"""
|
||||
|
||||
self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
|
||||
|
||||
sql = "SELECT COUNT(*) FROM state_groups WHERE room_id = ?"
|
||||
rows = self.get_success(
|
||||
self.store.db_pool.execute("test_deduplicate_joins", sql, self.room_id)
|
||||
)
|
||||
initial_count = rows[0][0]
|
||||
|
||||
self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
|
||||
rows = self.get_success(
|
||||
self.store.db_pool.execute("test_deduplicate_joins", sql, self.room_id)
|
||||
)
|
||||
new_count = rows[0][0]
|
||||
|
||||
self.assertEqual(initial_count, new_count)
|
||||
|
|
|
@ -30,7 +30,7 @@ from synapse.storage.database import (
|
|||
)
|
||||
from synapse.storage.engines import IncorrectDatabaseSetup
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.storage.util.sequence import (
|
||||
LocalSequenceGenerator,
|
||||
PostgresSequenceGenerator,
|
||||
|
@ -42,144 +42,6 @@ from tests.unittest import HomeserverTestCase
|
|||
from tests.utils import USE_POSTGRES_FOR_TESTS
|
||||
|
||||
|
||||
class StreamIdGeneratorTestCase(HomeserverTestCase):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
self.db_pool: DatabasePool = self.store.db_pool
|
||||
|
||||
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
||||
|
||||
def _setup_db(self, txn: LoggingTransaction) -> None:
|
||||
txn.execute(
|
||||
"""
|
||||
CREATE TABLE foobar (
|
||||
stream_id BIGINT NOT NULL,
|
||||
data TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
txn.execute("INSERT INTO foobar VALUES (123, 'hello world');")
|
||||
|
||||
def _create_id_generator(self) -> StreamIdGenerator:
|
||||
def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
|
||||
return StreamIdGenerator(
|
||||
db_conn=conn,
|
||||
notifier=self.hs.get_replication_notifier(),
|
||||
table="foobar",
|
||||
column="stream_id",
|
||||
)
|
||||
|
||||
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
|
||||
|
||||
def test_initial_value(self) -> None:
|
||||
"""Check that we read the current token from the DB."""
|
||||
id_gen = self._create_id_generator()
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
|
||||
def test_single_gen_next(self) -> None:
|
||||
"""Check that we correctly increment the current token from the DB."""
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
async def test_gen_next() -> None:
|
||||
async with id_gen.get_next() as next_id:
|
||||
# We haven't persisted `next_id` yet; current token is still 123
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
# But we did learn what the next value is
|
||||
self.assertEqual(next_id, 124)
|
||||
|
||||
# Once the context manager closes we assume that the `next_id` has been
|
||||
# written to the DB.
|
||||
self.assertEqual(id_gen.get_current_token(), 124)
|
||||
|
||||
self.get_success(test_gen_next())
|
||||
|
||||
def test_multiple_gen_nexts(self) -> None:
|
||||
"""Check that we handle overlapping calls to gen_next sensibly."""
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
async def test_gen_next() -> None:
|
||||
ctx1 = id_gen.get_next()
|
||||
ctx2 = id_gen.get_next()
|
||||
ctx3 = id_gen.get_next()
|
||||
|
||||
# Request three new stream IDs.
|
||||
self.assertEqual(await ctx1.__aenter__(), 124)
|
||||
self.assertEqual(await ctx2.__aenter__(), 125)
|
||||
self.assertEqual(await ctx3.__aenter__(), 126)
|
||||
|
||||
# None are persisted: current token unchanged.
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
|
||||
# Persist each in turn.
|
||||
await ctx1.__aexit__(None, None, None)
|
||||
self.assertEqual(id_gen.get_current_token(), 124)
|
||||
await ctx2.__aexit__(None, None, None)
|
||||
self.assertEqual(id_gen.get_current_token(), 125)
|
||||
await ctx3.__aexit__(None, None, None)
|
||||
self.assertEqual(id_gen.get_current_token(), 126)
|
||||
|
||||
self.get_success(test_gen_next())
|
||||
|
||||
def test_multiple_gen_nexts_closed_in_different_order(self) -> None:
|
||||
"""Check that we handle overlapping calls to gen_next, even when their IDs
|
||||
created and persisted in different orders."""
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
async def test_gen_next() -> None:
|
||||
ctx1 = id_gen.get_next()
|
||||
ctx2 = id_gen.get_next()
|
||||
ctx3 = id_gen.get_next()
|
||||
|
||||
# Request three new stream IDs.
|
||||
self.assertEqual(await ctx1.__aenter__(), 124)
|
||||
self.assertEqual(await ctx2.__aenter__(), 125)
|
||||
self.assertEqual(await ctx3.__aenter__(), 126)
|
||||
|
||||
# None are persisted: current token unchanged.
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
|
||||
# Persist them in a different order, starting with 126 from ctx3.
|
||||
await ctx3.__aexit__(None, None, None)
|
||||
# We haven't persisted 124 from ctx1 yet---current token is still 123.
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
|
||||
# Now persist 124 from ctx1.
|
||||
await ctx1.__aexit__(None, None, None)
|
||||
# Current token is then 124, waiting for 125 to be persisted.
|
||||
self.assertEqual(id_gen.get_current_token(), 124)
|
||||
|
||||
# Finally persist 125 from ctx2.
|
||||
await ctx2.__aexit__(None, None, None)
|
||||
# Current token is then 126 (skipping over 125).
|
||||
self.assertEqual(id_gen.get_current_token(), 126)
|
||||
|
||||
self.get_success(test_gen_next())
|
||||
|
||||
def test_gen_next_while_still_waiting_for_persistence(self) -> None:
|
||||
"""Check that we handle overlapping calls to gen_next."""
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
async def test_gen_next() -> None:
|
||||
ctx1 = id_gen.get_next()
|
||||
ctx2 = id_gen.get_next()
|
||||
ctx3 = id_gen.get_next()
|
||||
|
||||
# Request two new stream IDs.
|
||||
self.assertEqual(await ctx1.__aenter__(), 124)
|
||||
self.assertEqual(await ctx2.__aenter__(), 125)
|
||||
|
||||
# Persist ctx2 first.
|
||||
await ctx2.__aexit__(None, None, None)
|
||||
# Still waiting on ctx1's ID to be persisted.
|
||||
self.assertEqual(id_gen.get_current_token(), 123)
|
||||
|
||||
# Now request a third stream ID. It should be 126 (the smallest ID that
|
||||
# we've not yet handed out.)
|
||||
self.assertEqual(await ctx3.__aenter__(), 126)
|
||||
|
||||
self.get_success(test_gen_next())
|
||||
|
||||
|
||||
class MultiWriterIdGeneratorBase(HomeserverTestCase):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
|
|
Loading…
Reference in a new issue