Move back to the main store

This commit is contained in:
Eric Eastwood 2024-08-21 11:14:15 -05:00
parent 0233e20aa3
commit a5e06c6a8d
3 changed files with 475 additions and 475 deletions

View file

@ -24,9 +24,9 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast
import attr
from synapse.api.constants import EventContentFields, RelationTypes
from synapse.api.constants import EventContentFields, Membership, RelationTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.events import EventBase, make_event_from_dict
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@ -34,9 +34,18 @@ from synapse.storage.database import (
LoggingTransaction,
make_tuple_comparison_clause,
)
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.databases.main.events import (
SLIDING_SYNC_RELEVANT_STATE_SET,
PersistEventsStore,
SlidingSyncMembershipInfo,
SlidingSyncMembershipSnapshotSharedInsertValues,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Cursor
from synapse.types import JsonDict, StrCollection
from synapse.types import JsonDict, StateMap, StrCollection
from synapse.types.handlers import SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES
from synapse.types.state import StateFilter
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -78,6 +87,11 @@ class _BackgroundUpdates:
EVENTS_JUMP_TO_DATE_INDEX = "events_jump_to_date_index"
SLIDING_SYNC_JOINED_ROOMS_BACKFILL = "sliding_sync_joined_rooms_backfill"
SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL = (
"sliding_sync_membership_snapshots_backfill"
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _CalculateChainCover:
@ -97,7 +111,7 @@ class _CalculateChainCover:
finished_room_map: Dict[str, Tuple[int, int]]
class EventsBackgroundUpdatesStore(SQLBaseStore):
class EventsBackgroundUpdatesStore(EventsWorkerStore, SQLBaseStore):
def __init__(
self,
database: DatabasePool,
@ -279,6 +293,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
where_clause="NOT outlier",
)
# Backfill the sliding sync tables
self.db_pool.updates.register_background_update_handler(
_BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BACKFILL,
self._sliding_sync_joined_rooms_backfill,
)
self.db_pool.updates.register_background_update_handler(
_BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL,
self._sliding_sync_membership_snapshots_backfill,
)
async def _background_reindex_fields_sender(
self, progress: JsonDict, batch_size: int
) -> int:
@ -1073,7 +1097,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
PersistEventsStore._add_chain_cover_index(
txn,
self.db_pool,
self.event_chain_id_gen, # type: ignore[attr-defined]
self.event_chain_id_gen,
event_to_room_id,
event_to_types,
cast(Dict[str, StrCollection], event_to_auth_chain),
@ -1516,3 +1540,443 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
return batch_size
async def _sliding_sync_joined_rooms_backfill(
self, progress: JsonDict, batch_size: int
) -> int:
"""
Handles backfilling the `sliding_sync_joined_rooms` table.

Each call processes up to `batch_size` rooms taken from `current_state_events`
in ascending `room_id` order, resuming after `progress["last_room_id"]`. When a
batch finds no more rooms, the background update is ended.

Args:
progress: Saved progress for this background update (`last_room_id` is read
from it, defaulting to the empty string)
batch_size: Maximum number of rooms to process in this call

Returns:
The number of rooms processed in this batch (0 when there is nothing left)
"""
last_room_id = progress.get("last_room_id", "")
def make_sql_clause_for_get_last_event_pos_in_room(
database_engine: BaseDatabaseEngine,
event_types: Optional[StrCollection] = None,
) -> Tuple[str, list]:
"""
Builds a query that selects the `stream_ordering` of the most recent
non-outlier, non-rejected event in a room, optionally restricted to an
allowlist of event types.

The query starts with a `?` placeholder for the room ID, which is *not*
included in the returned args; the caller must supply it before them.

Based on `get_last_event_pos_in_room_before_stream_ordering(...)`

Args:
database_engine
event_types: Optional allowlist of event types to filter by

Returns:
A tuple of SQL query and the args for the event type filter (empty when no
filter is applied)
"""
event_type_clause = ""
event_type_args: List[str] = []
# Only add the type filter when there is at least one type to filter by
if event_types is not None and len(event_types) > 0:
event_type_clause, event_type_args = make_in_list_sql_clause(
database_engine, "type", event_types
)
event_type_clause = f"AND {event_type_clause}"
sql = f"""
SELECT stream_ordering
FROM events
LEFT JOIN rejections USING (event_id)
WHERE room_id = ?
{event_type_clause}
AND NOT outlier
AND rejections.event_id IS NULL
ORDER BY stream_ordering DESC
LIMIT 1
"""
return sql, event_type_args
def _txn(txn: LoggingTransaction) -> int:
# Fetch the set of room IDs that we want to update
txn.execute(
"""
SELECT DISTINCT room_id FROM current_state_events
WHERE room_id > ?
ORDER BY room_id ASC
LIMIT ?
""",
(last_room_id, batch_size),
)
rooms_to_update_rows = txn.fetchall()
# No rooms left after `last_room_id`; the caller ends the update on 0
if not rooms_to_update_rows:
return 0
for (room_id,) in rooms_to_update_rows:
# TODO: Handle redactions
current_state_map = PersistEventsStore._get_relevant_sliding_sync_current_state_event_ids_txn(
txn, room_id
)
# We're iterating over rooms pulled from the current_state_events table
# so we should have some current state for each room
assert current_state_map
sliding_sync_joined_rooms_insert_map = PersistEventsStore._get_sliding_sync_insert_values_from_state_ids_map_txn(
txn, current_state_map
)
# We should have some insert values for each room, even if they are `None`
assert sliding_sync_joined_rooms_insert_map
# Sub-query for the latest event of any type in the room
(
most_recent_event_stream_ordering_clause,
most_recent_event_stream_ordering_args,
) = make_sql_clause_for_get_last_event_pos_in_room(
txn.database_engine, event_types=None
)
# Sub-query for the latest event whose type is one of the bump event types
bump_stamp_clause, bump_stamp_args = (
make_sql_clause_for_get_last_event_pos_in_room(
txn.database_engine,
event_types=SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES,
)
)
# Pulling keys/values separately is safe and will produce congruent
# lists
insert_keys = sliding_sync_joined_rooms_insert_map.keys()
insert_values = sliding_sync_joined_rooms_insert_map.values()
# Upsert the row: both stream orderings are computed by the sub-queries
# and an existing row for the room is overwritten with the new values.
sql = f"""
INSERT INTO sliding_sync_joined_rooms
(room_id, event_stream_ordering, bump_stamp, {", ".join(insert_keys)})
VALUES (
?,
({most_recent_event_stream_ordering_clause}),
({bump_stamp_clause}),
{", ".join("?" for _ in insert_values)}
)
ON CONFLICT (room_id)
DO UPDATE SET
event_stream_ordering = EXCLUDED.event_stream_ordering,
bump_stamp = EXCLUDED.bump_stamp,
{", ".join(f"{key} = EXCLUDED.{key}" for key in insert_keys)}
"""
# The args must line up with the `?` placeholders in order: the row's
# `room_id`, then `room_id` + filter args for each of the two sub-queries,
# then the state-derived column values.
args = (
[room_id, room_id]
+ most_recent_event_stream_ordering_args
+ [room_id]
+ bump_stamp_args
+ list(insert_values)
)
txn.execute(sql, args)
# Record the last room handled (rows are sorted by `room_id` ascending) in
# the same transaction so the next batch resumes after it
self.db_pool.updates._background_update_progress_txn(
txn,
_BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BACKFILL,
{"last_room_id": rooms_to_update_rows[-1][0]},
)
return len(rooms_to_update_rows)
count = await self.db_pool.runInteraction(
"sliding_sync_joined_rooms_backfill", _txn
)
# Nothing was processed, so there are no rooms left: finish the update
if not count:
await self.db_pool.updates._end_background_update(
_BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BACKFILL
)
return count
async def _sliding_sync_membership_snapshots_backfill(
self, progress: JsonDict, batch_size: int
) -> int:
"""
Handles backfilling the `sliding_sync_membership_snapshots` table.

Each call processes up to `batch_size` rows from `local_current_membership` in
ascending `event_stream_ordering` order, resuming after
`progress["last_event_stream_ordering"]`. When a batch finds no more rows, the
background update is ended.

Args:
progress: Saved progress for this background update
(`last_event_stream_ordering` is read from it)
batch_size: Maximum number of membership rows to process in this call

Returns:
The number of membership rows processed in this batch (0 when there is
nothing left)
"""
# With no saved progress, start from a very low sentinel value (presumably
# lower than any stored stream ordering) so the first batch starts at the
# beginning.
last_event_stream_ordering = progress.get(
"last_event_stream_ordering", -(1 << 31)
)
def _find_memberships_to_update_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, str, str, str, int, bool]]:
# Fetch the set of event IDs that we want to update
#
# Each row is (room_id, user_id, sender, event_id, membership,
# event_stream_ordering, outlier)
txn.execute(
"""
SELECT
c.room_id,
c.user_id,
e.sender,
c.event_id,
c.membership,
c.event_stream_ordering,
e.outlier
FROM local_current_membership as c
INNER JOIN events AS e USING (event_id)
WHERE event_stream_ordering > ?
ORDER BY event_stream_ordering ASC
LIMIT ?
""",
(last_event_stream_ordering, batch_size),
)
memberships_to_update_rows = cast(
List[Tuple[str, str, str, str, str, int, bool]], txn.fetchall()
)
return memberships_to_update_rows
memberships_to_update_rows = await self.db_pool.runInteraction(
"sliding_sync_membership_snapshots_backfill._find_memberships_to_update_txn",
_find_memberships_to_update_txn,
)
# No memberships left to process: finish the background update
if not memberships_to_update_rows:
await self.db_pool.updates._end_background_update(
_BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL
)
return 0
def _find_previous_membership_txn(
txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
) -> Tuple[str, str]:
# Find the most recent membership event for the user in the room before the
# given stream ordering (expected to be the invite/knock event before the
# leave event).
#
# NOTE(review): the query does not filter on the membership type, so the row
# returned is not guaranteed to be an invite/knock.
txn.execute(
"""
SELECT event_id, membership
FROM room_memberships
WHERE
room_id = ?
AND user_id = ?
AND event_stream_ordering < ?
ORDER BY event_stream_ordering DESC
LIMIT 1
""",
(
room_id,
user_id,
stream_ordering,
),
)
row = txn.fetchone()
# We should see a corresponding previous invite/knock event
assert row is not None
event_id, membership = row
return event_id, membership
# Map from (room_id, user_id) to ...
to_insert_membership_snapshots: Dict[
Tuple[str, str], SlidingSyncMembershipSnapshotSharedInsertValues
] = {}
to_insert_membership_infos: Dict[Tuple[str, str], SlidingSyncMembershipInfo] = (
{}
)
# Work out the values to insert for each membership; the actual writes happen
# afterwards in a single transaction (`_backfill_table_txn`).
for (
room_id,
user_id,
sender,
membership_event_id,
membership,
membership_event_stream_ordering,
is_outlier,
) in memberships_to_update_rows:
# We don't know how to handle `membership` values other than these. The
# code below would need to be updated.
assert membership in (
Membership.JOIN,
Membership.INVITE,
Membership.KNOCK,
Membership.LEAVE,
Membership.BAN,
)
# Map of values to insert/update in the `sliding_sync_membership_snapshots` table
sliding_sync_membership_snapshots_insert_map: (
SlidingSyncMembershipSnapshotSharedInsertValues
) = {}
if membership == Membership.JOIN:
# If we're still joined, we can pull from current state.
current_state_ids_map: StateMap[
str
] = await self.hs.get_storage_controllers().state.get_current_state_ids(
room_id,
state_filter=StateFilter.from_types(
SLIDING_SYNC_RELEVANT_STATE_SET
),
# Partially-stated rooms should have all state events except for
# remote membership events so we don't need to wait at all because
# we only want some non-membership state
await_full_state=False,
)
# We're iterating over rooms that we are joined to so they should
# have `current_state_events` and we should have some current state
# for each room
assert current_state_ids_map
fetched_events = await self.get_events(current_state_ids_map.values())
# Swap the event IDs for the full events, keeping the same state keys
current_state_map: StateMap[EventBase] = {
state_key: fetched_events[event_id]
for state_key, event_id in current_state_ids_map.items()
}
state_insert_values = (
PersistEventsStore._get_sliding_sync_insert_values_from_state_map(
current_state_map
)
)
sliding_sync_membership_snapshots_insert_map.update(state_insert_values)
# We should have some insert values for each room, even if they are `None`
assert sliding_sync_membership_snapshots_insert_map
# We have current state to work from
sliding_sync_membership_snapshots_insert_map["has_known_state"] = True
elif membership in (Membership.INVITE, Membership.KNOCK) or (
membership == Membership.LEAVE and is_outlier
):
invite_or_knock_event_id = membership_event_id
invite_or_knock_membership = membership
# If the event is an `out_of_band_membership` (special case of
# `outlier`), we never had historical state so we have to pull from
# the stripped state on the previous invite/knock event. This gives
# us a consistent view of the room state regardless of your
# membership (i.e. the room shouldn't disappear if you're using the
# `is_encrypted` filter and you leave).
if membership == Membership.LEAVE and is_outlier:
invite_or_knock_event_id, invite_or_knock_membership = (
await self.db_pool.runInteraction(
"sliding_sync_membership_snapshots_backfill._find_previous_membership",
_find_previous_membership_txn,
room_id,
user_id,
membership_event_stream_ordering,
)
)
# Pull from the stripped state on the invite/knock event
invite_or_knock_event = await self.get_event(invite_or_knock_event_id)
# Stays `None` if the membership is neither an invite nor a knock, or if
# the event has no stripped state in `unsigned`
raw_stripped_state_events = None
if invite_or_knock_membership == Membership.INVITE:
invite_room_state = invite_or_knock_event.unsigned.get(
"invite_room_state"
)
raw_stripped_state_events = invite_room_state
elif invite_or_knock_membership == Membership.KNOCK:
knock_room_state = invite_or_knock_event.unsigned.get(
"knock_room_state"
)
raw_stripped_state_events = knock_room_state
sliding_sync_membership_snapshots_insert_map = await self.db_pool.runInteraction(
"sliding_sync_membership_snapshots_backfill._get_sliding_sync_insert_values_from_stripped_state_txn",
PersistEventsStore._get_sliding_sync_insert_values_from_stripped_state_txn,
raw_stripped_state_events,
)
# We should have some insert values for each room, even if no
# stripped state is on the event because we still want to record
# that we have no known state
assert sliding_sync_membership_snapshots_insert_map
elif membership in (Membership.LEAVE, Membership.BAN):
# Pull from historical state
state_ids_map = await self.hs.get_storage_controllers().state.get_state_ids_for_event(
membership_event_id,
state_filter=StateFilter.from_types(
SLIDING_SYNC_RELEVANT_STATE_SET
),
# Partially-stated rooms should have all state events except for
# remote membership events so we don't need to wait at all because
# we only want some non-membership state
await_full_state=False,
)
fetched_events = await self.get_events(state_ids_map.values())
# Swap the event IDs for the full events, keeping the same state keys
state_map: StateMap[EventBase] = {
state_key: fetched_events[event_id]
for state_key, event_id in state_ids_map.items()
}
state_insert_values = (
PersistEventsStore._get_sliding_sync_insert_values_from_state_map(
state_map
)
)
sliding_sync_membership_snapshots_insert_map.update(state_insert_values)
# We should have some insert values for each room, even if they are `None`
assert sliding_sync_membership_snapshots_insert_map
# We have historical state to work from
sliding_sync_membership_snapshots_insert_map["has_known_state"] = True
else:
# We don't know how to handle this type of membership yet
#
# FIXME: We should use `assert_never` here but for some reason
# the exhaustive matching doesn't recognize the `Never` here.
# assert_never(membership)
raise AssertionError(
f"Unexpected membership {membership} ({membership_event_id}) that we don't know how to handle yet"
)
to_insert_membership_snapshots[(room_id, user_id)] = (
sliding_sync_membership_snapshots_insert_map
)
to_insert_membership_infos[(room_id, user_id)] = SlidingSyncMembershipInfo(
user_id=user_id,
sender=sender,
membership_event_id=membership_event_id,
)
def _backfill_table_txn(txn: LoggingTransaction) -> None:
# Insert one snapshot row per (room_id, user_id) collected above
for key, insert_map in to_insert_membership_snapshots.items():
room_id, user_id = key
membership_info = to_insert_membership_infos[key]
membership_event_id = membership_info.membership_event_id
# Pulling keys/values separately is safe and will produce congruent
# lists
insert_keys = insert_map.keys()
insert_values = insert_map.values()
# We don't need to do anything `ON CONFLICT` because we never partially
# insert/update the snapshots
#
# `membership_event_id` is passed three times: once for the column itself
# and once for each of the two sub-queries that look up the `membership`
# and `stream_ordering` of that event.
txn.execute(
f"""
INSERT INTO sliding_sync_membership_snapshots
(room_id, user_id, membership_event_id, membership, event_stream_ordering
{("," + ", ".join(insert_keys)) if insert_keys else ""})
VALUES (
?, ?, ?,
(SELECT membership FROM room_memberships WHERE event_id = ?),
(SELECT stream_ordering FROM events WHERE event_id = ?)
{("," + ", ".join("?" for _ in insert_values)) if insert_values else ""}
)
ON CONFLICT (room_id, user_id)
DO NOTHING
""",
[
room_id,
user_id,
membership_event_id,
membership_event_id,
membership_event_id,
]
+ list(insert_values),
)
await self.db_pool.runInteraction(
"sliding_sync_membership_snapshots_backfill", _backfill_table_txn
)
# Update the progress
#
# The rows were fetched in ascending `event_stream_ordering` order so the last
# row holds the position to resume after.
(
_room_id,
_user_id,
_sender,
_membership_event_id,
_membership,
membership_event_stream_ordering,
_is_outlier,
) = memberships_to_update_rows[-1]
await self.db_pool.updates._background_update_progress(
_BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL,
{"last_event_stream_ordering": membership_event_stream_ordering},
)
return len(memberships_to_update_rows)

View file

@ -20,28 +20,17 @@
#
import logging
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
from typing_extensions import assert_never
from synapse.api.constants import Membership
from synapse.events import EventBase
from synapse.logging.opentracing import tag_args, trace
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.events import (
SLIDING_SYNC_RELEVANT_STATE_SET,
PersistEventsStore,
SlidingSyncMembershipInfo,
SlidingSyncMembershipSnapshotSharedInsertValues,
)
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.types import JsonDict, MutableStateMap, StateMap, StrCollection
from synapse.types.handlers import SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES
from synapse.storage.engines import PostgresEngine
from synapse.types import MutableStateMap, StateMap
from synapse.types.state import StateFilter
from synapse.util.caches import intern_string
@ -54,13 +43,6 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
class _BackgroundUpdates:
# Names of the background updates that backfill the sliding sync tables. These
# are the names the handlers are registered under via
# `register_background_update_handler(...)`.
SLIDING_SYNC_JOINED_ROOMS_BACKFILL = "sliding_sync_joined_rooms_backfill"
SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL = (
"sliding_sync_membership_snapshots_backfill"
)
class StateGroupBackgroundUpdateStore(SQLBaseStore):
"""Defines functions related to state groups needed to run the state background
updates.
@ -367,16 +349,6 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
columns=["event_stream_ordering"],
)
# Backfill the sliding sync tables
self.db_pool.updates.register_background_update_handler(
_BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BACKFILL,
self._sliding_sync_joined_rooms_backfill,
)
self.db_pool.updates.register_background_update_handler(
_BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL,
self._sliding_sync_membership_snapshots_backfill,
)
async def _background_deduplicate_state(
self, progress: dict, batch_size: int
) -> int:
@ -552,439 +524,3 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
)
return 1
async def _sliding_sync_joined_rooms_backfill(
self, progress: JsonDict, batch_size: int
) -> int:
"""
Handles backfilling the `sliding_sync_joined_rooms` table.

Each call processes up to `batch_size` rooms taken from `current_state_events`
in ascending `room_id` order, resuming after `progress["last_room_id"]`. When a
batch finds no more rooms, the background update is ended.

Args:
progress: Saved progress for this background update (`last_room_id` is read
from it, defaulting to the empty string)
batch_size: Maximum number of rooms to process in this call

Returns:
The number of rooms processed in this batch (0 when there is nothing left)
"""
last_room_id = progress.get("last_room_id", "")
def make_sql_clause_for_get_last_event_pos_in_room(
database_engine: BaseDatabaseEngine,
event_types: Optional[StrCollection] = None,
) -> Tuple[str, list]:
"""
Builds a query that selects the `stream_ordering` of the most recent
non-outlier, non-rejected event in a room, optionally restricted to an
allowlist of event types.

The query starts with a `?` placeholder for the room ID, which is *not*
included in the returned args; the caller must supply it before them.

Based on `get_last_event_pos_in_room_before_stream_ordering(...)`

Args:
database_engine
event_types: Optional allowlist of event types to filter by

Returns:
A tuple of SQL query and the args for the event type filter (empty when no
filter is applied)
"""
event_type_clause = ""
event_type_args: List[str] = []
# Only add the type filter when there is at least one type to filter by
if event_types is not None and len(event_types) > 0:
event_type_clause, event_type_args = make_in_list_sql_clause(
database_engine, "type", event_types
)
event_type_clause = f"AND {event_type_clause}"
sql = f"""
SELECT stream_ordering
FROM events
LEFT JOIN rejections USING (event_id)
WHERE room_id = ?
{event_type_clause}
AND NOT outlier
AND rejections.event_id IS NULL
ORDER BY stream_ordering DESC
LIMIT 1
"""
return sql, event_type_args
def _txn(txn: LoggingTransaction) -> int:
# Fetch the set of room IDs that we want to update
txn.execute(
"""
SELECT DISTINCT room_id FROM current_state_events
WHERE room_id > ?
ORDER BY room_id ASC
LIMIT ?
""",
(last_room_id, batch_size),
)
rooms_to_update_rows = txn.fetchall()
# No rooms left after `last_room_id`; the caller ends the update on 0
if not rooms_to_update_rows:
return 0
for (room_id,) in rooms_to_update_rows:
# TODO: Handle redactions
current_state_map = PersistEventsStore._get_relevant_sliding_sync_current_state_event_ids_txn(
txn, room_id
)
# We're iterating over rooms pulled from the current_state_events table
# so we should have some current state for each room
assert current_state_map
sliding_sync_joined_rooms_insert_map = PersistEventsStore._get_sliding_sync_insert_values_from_state_ids_map_txn(
txn, current_state_map
)
# We should have some insert values for each room, even if they are `None`
assert sliding_sync_joined_rooms_insert_map
# Sub-query for the latest event of any type in the room
(
most_recent_event_stream_ordering_clause,
most_recent_event_stream_ordering_args,
) = make_sql_clause_for_get_last_event_pos_in_room(
txn.database_engine, event_types=None
)
# Sub-query for the latest event whose type is one of the bump event types
bump_stamp_clause, bump_stamp_args = (
make_sql_clause_for_get_last_event_pos_in_room(
txn.database_engine,
event_types=SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES,
)
)
# Pulling keys/values separately is safe and will produce congruent
# lists
insert_keys = sliding_sync_joined_rooms_insert_map.keys()
insert_values = sliding_sync_joined_rooms_insert_map.values()
# Upsert the row: both stream orderings are computed by the sub-queries
# and an existing row for the room is overwritten with the new values.
sql = f"""
INSERT INTO sliding_sync_joined_rooms
(room_id, event_stream_ordering, bump_stamp, {", ".join(insert_keys)})
VALUES (
?,
({most_recent_event_stream_ordering_clause}),
({bump_stamp_clause}),
{", ".join("?" for _ in insert_values)}
)
ON CONFLICT (room_id)
DO UPDATE SET
event_stream_ordering = EXCLUDED.event_stream_ordering,
bump_stamp = EXCLUDED.bump_stamp,
{", ".join(f"{key} = EXCLUDED.{key}" for key in insert_keys)}
"""
# The args must line up with the `?` placeholders in order: the row's
# `room_id`, then `room_id` + filter args for each of the two sub-queries,
# then the state-derived column values.
args = (
[room_id, room_id]
+ most_recent_event_stream_ordering_args
+ [room_id]
+ bump_stamp_args
+ list(insert_values)
)
txn.execute(sql, args)
# Record the last room handled (rows are sorted by `room_id` ascending) in
# the same transaction so the next batch resumes after it
self.db_pool.updates._background_update_progress_txn(
txn,
_BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BACKFILL,
{"last_room_id": rooms_to_update_rows[-1][0]},
)
return len(rooms_to_update_rows)
count = await self.db_pool.runInteraction(
"sliding_sync_joined_rooms_backfill", _txn
)
# Nothing was processed, so there are no rooms left: finish the update
if not count:
await self.db_pool.updates._end_background_update(
_BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BACKFILL
)
return count
async def _sliding_sync_membership_snapshots_backfill(
self, progress: JsonDict, batch_size: int
) -> int:
"""
Handles backfilling the `sliding_sync_membership_snapshots` table.

Each call processes up to `batch_size` rows from `local_current_membership` in
ascending `event_stream_ordering` order, resuming after
`progress["last_event_stream_ordering"]`. When a batch finds no more rows, the
background update is ended.

Args:
progress: Saved progress for this background update
(`last_event_stream_ordering` is read from it)
batch_size: Maximum number of membership rows to process in this call

Returns:
The number of membership rows processed in this batch (0 when there is
nothing left)
"""
# With no saved progress, start from a very low sentinel value (presumably
# lower than any stored stream ordering) so the first batch starts at the
# beginning.
last_event_stream_ordering = progress.get(
"last_event_stream_ordering", -(1 << 31)
)
def _find_memberships_to_update_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, str, str, str, int, bool]]:
# Fetch the set of event IDs that we want to update
#
# Each row is (room_id, user_id, sender, event_id, membership,
# event_stream_ordering, outlier)
txn.execute(
"""
SELECT
c.room_id,
c.user_id,
e.sender,
c.event_id,
c.membership,
c.event_stream_ordering,
e.outlier
FROM local_current_membership as c
INNER JOIN events AS e USING (event_id)
WHERE event_stream_ordering > ?
ORDER BY event_stream_ordering ASC
LIMIT ?
""",
(last_event_stream_ordering, batch_size),
)
memberships_to_update_rows = cast(
List[Tuple[str, str, str, str, str, int, bool]], txn.fetchall()
)
return memberships_to_update_rows
memberships_to_update_rows = await self.db_pool.runInteraction(
"sliding_sync_membership_snapshots_backfill._find_memberships_to_update_txn",
_find_memberships_to_update_txn,
)
# No memberships left to process: finish the background update
if not memberships_to_update_rows:
await self.db_pool.updates._end_background_update(
_BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL
)
return 0
# Events and current state are read through the main store, obtained from the
# homeserver's storage controllers
store = self.hs.get_storage_controllers().main
def _find_previous_membership_txn(
txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
) -> Tuple[str, str]:
# Find the most recent membership event for the user in the room before the
# given stream ordering (expected to be the invite/knock event before the
# leave event).
#
# NOTE(review): the query does not filter on the membership type, so the row
# returned is not guaranteed to be an invite/knock.
txn.execute(
"""
SELECT event_id, membership
FROM room_memberships
WHERE
room_id = ?
AND user_id = ?
AND event_stream_ordering < ?
ORDER BY event_stream_ordering DESC
LIMIT 1
""",
(
room_id,
user_id,
stream_ordering,
),
)
row = txn.fetchone()
# We should see a corresponding previous invite/knock event
assert row is not None
event_id, membership = row
return event_id, membership
# Map from (room_id, user_id) to ...
to_insert_membership_snapshots: Dict[
Tuple[str, str], SlidingSyncMembershipSnapshotSharedInsertValues
] = {}
to_insert_membership_infos: Dict[Tuple[str, str], SlidingSyncMembershipInfo] = (
{}
)
# Work out the values to insert for each membership; the actual writes happen
# afterwards in a single transaction (`_backfill_table_txn`).
for (
room_id,
user_id,
sender,
membership_event_id,
membership,
membership_event_stream_ordering,
is_outlier,
) in memberships_to_update_rows:
# We don't know how to handle `membership` values other than these. The
# code below would need to be updated.
assert membership in (
Membership.JOIN,
Membership.INVITE,
Membership.KNOCK,
Membership.LEAVE,
Membership.BAN,
)
# Map of values to insert/update in the `sliding_sync_membership_snapshots` table
sliding_sync_membership_snapshots_insert_map: (
SlidingSyncMembershipSnapshotSharedInsertValues
) = {}
if membership == Membership.JOIN:
# If we're still joined, we can pull from current state.
current_state_ids_map: StateMap[str] = (
await store.get_partial_filtered_current_state_ids(
room_id,
state_filter=StateFilter.from_types(
SLIDING_SYNC_RELEVANT_STATE_SET
),
)
)
# We're iterating over rooms that we are joined to so they should
# have `current_state_events` and we should have some current state
# for each room
assert current_state_ids_map
fetched_events = await store.get_events(current_state_ids_map.values())
# Swap the event IDs for the full events, keeping the same state keys
current_state_map: StateMap[EventBase] = {
state_key: fetched_events[event_id]
for state_key, event_id in current_state_ids_map.items()
}
state_insert_values = (
PersistEventsStore._get_sliding_sync_insert_values_from_state_map(
current_state_map
)
)
sliding_sync_membership_snapshots_insert_map.update(state_insert_values)
# We should have some insert values for each room, even if they are `None`
assert sliding_sync_membership_snapshots_insert_map
# We have current state to work from
sliding_sync_membership_snapshots_insert_map["has_known_state"] = True
elif membership in (Membership.INVITE, Membership.KNOCK) or (
membership == Membership.LEAVE and is_outlier
):
invite_or_knock_event_id = membership_event_id
invite_or_knock_membership = membership
# If the event is an `out_of_band_membership` (special case of
# `outlier`), we never had historical state so we have to pull from
# the stripped state on the previous invite/knock event. This gives
# us a consistent view of the room state regardless of your
# membership (i.e. the room shouldn't disappear if you're using the
# `is_encrypted` filter and you leave).
if membership == Membership.LEAVE and is_outlier:
invite_or_knock_event_id, invite_or_knock_membership = (
await self.db_pool.runInteraction(
"sliding_sync_membership_snapshots_backfill._find_previous_membership",
_find_previous_membership_txn,
room_id,
user_id,
membership_event_stream_ordering,
)
)
# Pull from the stripped state on the invite/knock event
invite_or_knock_event = await store.get_event(invite_or_knock_event_id)
# Stays `None` if the membership is neither an invite nor a knock, or if
# the event has no stripped state in `unsigned`
raw_stripped_state_events = None
if invite_or_knock_membership == Membership.INVITE:
invite_room_state = invite_or_knock_event.unsigned.get(
"invite_room_state"
)
raw_stripped_state_events = invite_room_state
elif invite_or_knock_membership == Membership.KNOCK:
knock_room_state = invite_or_knock_event.unsigned.get(
"knock_room_state"
)
raw_stripped_state_events = knock_room_state
sliding_sync_membership_snapshots_insert_map = await self.db_pool.runInteraction(
"sliding_sync_membership_snapshots_backfill._get_sliding_sync_insert_values_from_stripped_state_txn",
PersistEventsStore._get_sliding_sync_insert_values_from_stripped_state_txn,
raw_stripped_state_events,
)
# We should have some insert values for each room, even if no
# stripped state is on the event because we still want to record
# that we have no known state
assert sliding_sync_membership_snapshots_insert_map
elif membership in (Membership.LEAVE, Membership.BAN):
# Pull from historical state
state_group = await store._get_state_group_for_event(
membership_event_id
)
# We should know the state for the event
assert state_group is not None
# Fetch the filtered state for that single state group; the result is
# indexed by state group below
state_by_group = await self.db_pool.runInteraction(
"sliding_sync_membership_snapshots_backfill._get_state_groups_from_groups_txn",
self._get_state_groups_from_groups_txn,
groups=[state_group],
state_filter=StateFilter.from_types(
SLIDING_SYNC_RELEVANT_STATE_SET
),
)
state_ids_map = state_by_group[state_group]
fetched_events = await store.get_events(state_ids_map.values())
# Swap the event IDs for the full events, keeping the same state keys
state_map: StateMap[EventBase] = {
state_key: fetched_events[event_id]
for state_key, event_id in state_ids_map.items()
}
state_insert_values = (
PersistEventsStore._get_sliding_sync_insert_values_from_state_map(
state_map
)
)
sliding_sync_membership_snapshots_insert_map.update(state_insert_values)
# We should have some insert values for each room, even if they are `None`
assert sliding_sync_membership_snapshots_insert_map
# We have historical state to work from
sliding_sync_membership_snapshots_insert_map["has_known_state"] = True
else:
# Every membership value allowed by the assertion above is handled by an
# earlier branch
assert_never(membership)
to_insert_membership_snapshots[(room_id, user_id)] = (
sliding_sync_membership_snapshots_insert_map
)
to_insert_membership_infos[(room_id, user_id)] = SlidingSyncMembershipInfo(
user_id=user_id,
sender=sender,
membership_event_id=membership_event_id,
)
def _backfill_table_txn(txn: LoggingTransaction) -> None:
# Insert one snapshot row per (room_id, user_id) collected above
for key, insert_map in to_insert_membership_snapshots.items():
room_id, user_id = key
membership_info = to_insert_membership_infos[key]
membership_event_id = membership_info.membership_event_id
# Pulling keys/values separately is safe and will produce congruent
# lists
insert_keys = insert_map.keys()
insert_values = insert_map.values()
# We don't need to do anything `ON CONFLICT` because we never partially
# insert/update the snapshots
#
# `membership_event_id` is passed three times: once for the column itself
# and once for each of the two sub-queries that look up the `membership`
# and `stream_ordering` of that event.
txn.execute(
f"""
INSERT INTO sliding_sync_membership_snapshots
(room_id, user_id, membership_event_id, membership, event_stream_ordering
{("," + ", ".join(insert_keys)) if insert_keys else ""})
VALUES (
?, ?, ?,
(SELECT membership FROM room_memberships WHERE event_id = ?),
(SELECT stream_ordering FROM events WHERE event_id = ?)
{("," + ", ".join("?" for _ in insert_values)) if insert_values else ""}
)
ON CONFLICT (room_id, user_id)
DO NOTHING
""",
[
room_id,
user_id,
membership_event_id,
membership_event_id,
membership_event_id,
]
+ list(insert_values),
)
await self.db_pool.runInteraction(
"sliding_sync_membership_snapshots_backfill", _backfill_table_txn
)
# Update the progress
#
# The rows were fetched in ascending `event_stream_ordering` order so the last
# row holds the position to resume after.
(
_room_id,
_user_id,
_sender,
_membership_event_id,
_membership,
membership_event_stream_ordering,
_is_outlier,
) = memberships_to_update_rows[-1]
await self.db_pool.updates._background_update_progress(
_BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BACKFILL,
{"last_event_stream_ordering": membership_event_stream_ordering},
)
return len(memberships_to_update_rows)

View file

@ -36,7 +36,7 @@ from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.databases.main.events import DeltaState
from synapse.storage.databases.state.bg_updates import _BackgroundUpdates
from synapse.storage.databases.main.events_bg_updates import _BackgroundUpdates
from synapse.types import StateMap
from synapse.util import Clock