Rename storage classes (#12913)

This commit is contained in:
Erik Johnston 2022-05-31 13:17:50 +01:00 committed by GitHub
parent e541bb9eed
commit 1e453053cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
53 changed files with 708 additions and 551 deletions

1
changelog.d/12913.misc Normal file
View file

@ -0,0 +1 @@
Rename storage classes.

View file

@ -22,7 +22,7 @@ from synapse.events import EventBase
from synapse.types import JsonDict, StateMap
if TYPE_CHECKING:
from synapse.storage import Storage
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
from synapse.storage.state import StateFilter
@ -84,7 +84,7 @@ class EventContext:
incomplete state.
"""
_storage: "Storage"
_storage: "StorageControllers"
rejected: Union[Literal[False], str] = False
_state_group: Optional[int] = None
state_group_before_event: Optional[int] = None
@ -97,7 +97,7 @@ class EventContext:
@staticmethod
def with_state(
storage: "Storage",
storage: "StorageControllers",
state_group: Optional[int],
state_group_before_event: Optional[int],
state_delta_due_to_event: Optional[StateMap[str]],
@ -117,7 +117,7 @@ class EventContext:
@staticmethod
def for_outlier(
storage: "Storage",
storage: "StorageControllers",
) -> "EventContext":
"""Return an EventContext instance suitable for persisting an outlier event"""
return EventContext(storage=storage)
@ -147,7 +147,7 @@ class EventContext:
}
@staticmethod
def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
def deserialize(storage: "StorageControllers", input: JsonDict) -> "EventContext":
"""Converts a dict that was produced by `serialize` back into a
EventContext.

View file

@ -109,7 +109,6 @@ class FederationServer(FederationBase):
super().__init__(hs)
self.handler = hs.get_federation_handler()
self.storage = hs.get_storage()
self._spam_checker = hs.get_spam_checker()
self._federation_event_handler = hs.get_federation_event_handler()
self.state = hs.get_state_handler()

View file

@ -30,8 +30,8 @@ logger = logging.getLogger(__name__)
class AdminHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self.state_storage = self.storage.state
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
async def get_whois(self, user: UserID) -> JsonDict:
connections = []
@ -197,7 +197,9 @@ class AdminHandler:
from_key = events[-1].internal_metadata.after
events = await filter_events_for_client(self.storage, user_id, events)
events = await filter_events_for_client(
self._storage_controllers, user_id, events
)
writer.write_events(room_id, events)
@ -233,7 +235,9 @@ class AdminHandler:
for event_id in extremities:
if not event_to_unseen_prevs[event_id]:
continue
state = await self.state_storage.get_state_for_event(event_id)
state = await self._state_storage_controller.get_state_for_event(
event_id
)
writer.write_state(room_id, event_id, state)
return writer.finished()

View file

@ -71,7 +71,7 @@ class DeviceWorkerHandler:
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
self.state_storage = hs.get_storage().state
self._state_storage = hs.get_storage_controllers().state
self._auth_handler = hs.get_auth_handler()
self.server_name = hs.hostname
@ -204,7 +204,7 @@ class DeviceWorkerHandler:
continue
# mapping from event_id -> state_dict
prev_state_ids = await self.state_storage.get_state_ids_for_events(
prev_state_ids = await self._state_storage.get_state_ids_for_events(
event_ids
)

View file

@ -139,7 +139,7 @@ class EventStreamHandler:
class EventHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self._storage_controllers = hs.get_storage_controllers()
async def get_event(
self,
@ -177,7 +177,7 @@ class EventHandler:
is_peeking = user.to_string() not in users
filtered = await filter_events_for_client(
self.storage, user.to_string(), [event], is_peeking=is_peeking
self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
)
if not filtered:

View file

@ -125,8 +125,8 @@ class FederationHandler:
self.hs = hs
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self.state_storage = self.storage.state
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
@ -324,7 +324,7 @@ class FederationHandler:
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = await filter_events_for_server(
self.storage,
self._storage_controllers,
self.server_name,
events_to_check,
redact=False,
@ -660,7 +660,7 @@ class FederationHandler:
# in the invitee's sync stream. It is stripped out for all other local users.
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
context = EventContext.for_outlier(self.storage)
context = EventContext.for_outlier(self._storage_controllers)
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@ -849,7 +849,7 @@ class FederationHandler:
)
)
context = EventContext.for_outlier(self.storage)
context = EventContext.for_outlier(self._storage_controllers)
await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@ -878,7 +878,7 @@ class FederationHandler:
await self.federation_client.send_leave(host_list, event)
context = EventContext.for_outlier(self.storage)
context = EventContext.for_outlier(self._storage_controllers)
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@ -1027,7 +1027,7 @@ class FederationHandler:
if event.internal_metadata.outlier:
raise NotFoundError("State not known at event %s" % (event_id,))
state_groups = await self.state_storage.get_state_groups_ids(
state_groups = await self._state_storage_controller.get_state_groups_ids(
room_id, [event_id]
)
@ -1078,7 +1078,9 @@ class FederationHandler:
],
)
events = await filter_events_for_server(self.storage, origin, events)
events = await filter_events_for_server(
self._storage_controllers, origin, events
)
return events
@ -1109,7 +1111,9 @@ class FederationHandler:
if not in_room:
raise AuthError(403, "Host not in room.")
events = await filter_events_for_server(self.storage, origin, [event])
events = await filter_events_for_server(
self._storage_controllers, origin, [event]
)
event = events[0]
return event
else:
@ -1138,7 +1142,7 @@ class FederationHandler:
)
missing_events = await filter_events_for_server(
self.storage, origin, missing_events
self._storage_controllers, origin, missing_events
)
return missing_events
@ -1480,9 +1484,11 @@ class FederationHandler:
# clear the lazy-loading flag.
logger.info("Updating current state for %s", room_id)
assert (
self.storage.persistence is not None
self._storage_controllers.persistence is not None
), "TODO(faster_joins): support for workers"
await self.storage.persistence.update_current_state(room_id)
await self._storage_controllers.persistence.update_current_state(
room_id
)
logger.info("Clearing partial-state flag for %s", room_id)
success = await self.store.clear_partial_state_room(room_id)

View file

@ -98,8 +98,8 @@ class FederationEventHandler:
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main
self._storage = hs.get_storage()
self._state_storage = self._storage.state
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
self._state_handler = hs.get_state_handler()
self._event_creation_handler = hs.get_event_creation_handler()
@ -535,7 +535,9 @@ class FederationEventHandler:
)
return
await self._store.update_state_for_partial_state_event(event, context)
self._state_storage.notify_event_un_partial_stated(event.event_id)
self._state_storage_controller.notify_event_un_partial_stated(
event.event_id
)
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
@ -835,7 +837,9 @@ class FederationEventHandler:
try:
# Get the state of the events we know about
ours = await self._state_storage.get_state_groups_ids(room_id, seen)
ours = await self._state_storage_controller.get_state_groups_ids(
room_id, seen
)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps: List[StateMap[str]] = list(ours.values())
@ -1436,7 +1440,7 @@ class FederationEventHandler:
# we're not bothering about room state, so flag the event as an outlier.
event.internal_metadata.outlier = True
context = EventContext.for_outlier(self._storage)
context = EventContext.for_outlier(self._storage_controllers)
try:
validate_event_for_room_version(room_version_obj, event)
check_auth_rules_for_event(room_version_obj, event, auth)
@ -1613,7 +1617,7 @@ class FederationEventHandler:
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
state_sets_d = await self._state_storage.get_state_groups_ids(
state_sets_d = await self._state_storage_controller.get_state_groups_ids(
event.room_id, extrem_ids
)
state_sets: List[StateMap[str]] = list(state_sets_d.values())
@ -1885,7 +1889,7 @@ class FederationEventHandler:
# create a new state group as a delta from the existing one.
prev_group = context.state_group
state_group = await self._state_storage.store_state_group(
state_group = await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=prev_group,
@ -1894,7 +1898,7 @@ class FederationEventHandler:
)
return EventContext.with_state(
storage=self._storage,
storage=self._storage_controllers,
state_group=state_group,
state_group_before_event=context.state_group_before_event,
state_delta_due_to_event=state_updates,
@ -1984,11 +1988,14 @@ class FederationEventHandler:
)
return result["max_stream_id"]
else:
assert self._storage.persistence
assert self._storage_controllers.persistence
# Note that this returns the events that were persisted, which may not be
# the same as were passed in if some were deduplicated due to transaction IDs.
events, max_stream_token = await self._storage.persistence.persist_events(
(
events,
max_stream_token,
) = await self._storage_controllers.persistence.persist_events(
event_and_contexts, backfilled=backfilled
)

View file

@ -67,8 +67,8 @@ class InitialSyncHandler:
]
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_storage = self.storage.state
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
async def snapshot_all_rooms(
self,
@ -198,7 +198,8 @@ class InitialSyncHandler:
event.stream_ordering,
)
deferred_room_state = run_in_background(
self.state_storage.get_state_for_events, [event.event_id]
self._state_storage_controller.get_state_for_events,
[event.event_id],
).addCallback(
lambda states: cast(StateMap[EventBase], states[event.event_id])
)
@ -218,7 +219,7 @@ class InitialSyncHandler:
).addErrback(unwrapFirstError)
messages = await filter_events_for_client(
self.storage, user_id, messages
self._storage_controllers, user_id, messages
)
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
@ -355,7 +356,9 @@ class InitialSyncHandler:
member_event_id: str,
is_peeking: bool,
) -> JsonDict:
room_state = await self.state_storage.get_state_for_event(member_event_id)
room_state = await self._state_storage_controller.get_state_for_event(
member_event_id
)
limit = pagin_config.limit if pagin_config else None
if limit is None:
@ -369,7 +372,7 @@ class InitialSyncHandler:
)
messages = await filter_events_for_client(
self.storage, user_id, messages, is_peeking=is_peeking
self._storage_controllers, user_id, messages, is_peeking=is_peeking
)
start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
@ -474,7 +477,7 @@ class InitialSyncHandler:
)
messages = await filter_events_for_client(
self.storage, user_id, messages, is_peeking=is_peeking
self._storage_controllers, user_id, messages, is_peeking=is_peeking
)
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)

View file

@ -84,8 +84,8 @@ class MessageHandler:
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self.state_storage = self.storage.state
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
self._event_serializer = hs.get_event_client_serializer()
self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages
@ -132,7 +132,7 @@ class MessageHandler:
assert (
membership_event_id is not None
), "check_user_in_room_or_world_readable returned invalid data"
room_state = await self.state_storage.get_state_for_events(
room_state = await self._state_storage_controller.get_state_for_events(
[membership_event_id], StateFilter.from_types([key])
)
data = room_state[membership_event_id].get(key)
@ -193,7 +193,7 @@ class MessageHandler:
# check whether the user is in the room at that time to determine
# whether they should be treated as peeking.
state_map = await self.state_storage.get_state_for_event(
state_map = await self._state_storage_controller.get_state_for_event(
last_event.event_id,
StateFilter.from_types([(EventTypes.Member, user_id)]),
)
@ -206,7 +206,7 @@ class MessageHandler:
is_peeking = not joined
visible_events = await filter_events_for_client(
self.storage,
self._storage_controllers,
user_id,
[last_event],
filter_send_to_client=False,
@ -214,8 +214,10 @@ class MessageHandler:
)
if visible_events:
room_state_events = await self.state_storage.get_state_for_events(
[last_event.event_id], state_filter=state_filter
room_state_events = (
await self._state_storage_controller.get_state_for_events(
[last_event.event_id], state_filter=state_filter
)
)
room_state: Mapping[Any, EventBase] = room_state_events[
last_event.event_id
@ -244,8 +246,10 @@ class MessageHandler:
assert (
membership_event_id is not None
), "check_user_in_room_or_world_readable returned invalid data"
room_state_events = await self.state_storage.get_state_for_events(
[membership_event_id], state_filter=state_filter
room_state_events = (
await self._state_storage_controller.get_state_for_events(
[membership_event_id], state_filter=state_filter
)
)
room_state = room_state_events[membership_event_id]
@ -402,7 +406,7 @@ class EventCreationHandler:
self.auth = hs.get_auth()
self._event_auth_handler = hs.get_event_auth_handler()
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self._storage_controllers = hs.get_storage_controllers()
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
@ -1032,7 +1036,7 @@ class EventCreationHandler:
# after it is created
if builder.internal_metadata.outlier:
event.internal_metadata.outlier = True
context = EventContext.for_outlier(self.storage)
context = EventContext.for_outlier(self._storage_controllers)
elif (
event.type == EventTypes.MSC2716_INSERTION
and state_event_ids
@ -1445,7 +1449,7 @@ class EventCreationHandler:
"""
extra_users = extra_users or []
assert self.storage.persistence is not None
assert self._storage_controllers.persistence is not None
assert self._events_shard_config.should_handle(
self._instance_name, event.room_id
)
@ -1679,7 +1683,7 @@ class EventCreationHandler:
event,
event_pos,
max_stream_token,
) = await self.storage.persistence.persist_event(
) = await self._storage_controllers.persistence.persist_event(
event, context=context, backfilled=backfilled
)

View file

@ -129,8 +129,8 @@ class PaginationHandler:
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self.state_storage = self.storage.state
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
self.clock = hs.get_clock()
self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler()
@ -352,7 +352,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id)
try:
async with self.pagination_lock.write(room_id):
await self.storage.purge_events.purge_history(
await self._storage_controllers.purge_events.purge_history(
room_id, token, delete_local_events
)
logger.info("[purge] complete")
@ -414,7 +414,7 @@ class PaginationHandler:
if joined:
raise SynapseError(400, "Users are still joined to this room")
await self.storage.purge_events.purge_room(room_id)
await self._storage_controllers.purge_events.purge_room(room_id)
async def get_messages(
self,
@ -529,7 +529,10 @@ class PaginationHandler:
events = await event_filter.filter(events)
events = await filter_events_for_client(
self.storage, user_id, events, is_peeking=(member_event_id is None)
self._storage_controllers,
user_id,
events,
is_peeking=(member_event_id is None),
)
# if after the filter applied there are no more events
@ -550,7 +553,7 @@ class PaginationHandler:
(EventTypes.Member, event.sender) for event in events
)
state_ids = await self.state_storage.get_state_ids_for_event(
state_ids = await self._state_storage_controller.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter
)
@ -664,7 +667,7 @@ class PaginationHandler:
400, "Users are still joined to this room"
)
await self.storage.purge_events.purge_room(room_id)
await self._storage_controllers.purge_events.purge_room(room_id)
logger.info("complete")
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE

View file

@ -69,7 +69,7 @@ class BundledAggregations:
class RelationsHandler:
def __init__(self, hs: "HomeServer"):
self._main_store = hs.get_datastores().main
self._storage = hs.get_storage()
self._storage_controllers = hs.get_storage_controllers()
self._auth = hs.get_auth()
self._clock = hs.get_clock()
self._event_handler = hs.get_event_handler()
@ -143,7 +143,10 @@ class RelationsHandler:
)
events = await filter_events_for_client(
self._storage, user_id, events, is_peeking=(member_event_id is None)
self._storage_controllers,
user_id,
events,
is_peeking=(member_event_id is None),
)
now = self._clock.time_msec()

View file

@ -1192,8 +1192,8 @@ class RoomContextHandler:
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self.state_storage = self.storage.state
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
self._relations_handler = hs.get_relations_handler()
async def get_event_context(
@ -1236,7 +1236,10 @@ class RoomContextHandler:
if use_admin_priviledge:
return events
return await filter_events_for_client(
self.storage, user.to_string(), events, is_peeking=is_peeking
self._storage_controllers,
user.to_string(),
events,
is_peeking=is_peeking,
)
event = await self.store.get_event(
@ -1293,7 +1296,7 @@ class RoomContextHandler:
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
state = await self.state_storage.get_state_for_events(
state = await self._state_storage_controller.get_state_for_events(
[last_event_id], state_filter=state_filter
)

View file

@ -17,7 +17,7 @@ class RoomBatchHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastores().main
self.state_storage = hs.get_storage().state
self._state_storage_controller = hs.get_storage_controllers().state
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
@ -141,7 +141,7 @@ class RoomBatchHandler:
) = await self.store.get_max_depth_of(event_ids)
# mapping from (type, state_key) -> state_event_id
assert most_recent_event_id is not None
prev_state_map = await self.state_storage.get_state_ids_for_event(
prev_state_map = await self._state_storage_controller.get_state_ids_for_event(
most_recent_event_id
)
# List of state event ID's

View file

@ -55,8 +55,8 @@ class SearchHandler:
self.hs = hs
self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler()
self.storage = hs.get_storage()
self.state_storage = self.storage.state
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
self.auth = hs.get_auth()
async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
@ -460,7 +460,7 @@ class SearchHandler:
filtered_events = await search_filter.filter([r["event"] for r in results])
events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
self._storage_controllers, user.to_string(), filtered_events
)
events.sort(key=lambda e: -rank_map[e.event_id])
@ -559,7 +559,7 @@ class SearchHandler:
filtered_events = await search_filter.filter([r["event"] for r in results])
events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
self._storage_controllers, user.to_string(), filtered_events
)
room_events.extend(events)
@ -644,11 +644,11 @@ class SearchHandler:
)
events_before = await filter_events_for_client(
self.storage, user.to_string(), res.events_before
self._storage_controllers, user.to_string(), res.events_before
)
events_after = await filter_events_for_client(
self.storage, user.to_string(), res.events_after
self._storage_controllers, user.to_string(), res.events_after
)
context: JsonDict = {
@ -677,7 +677,7 @@ class SearchHandler:
[(EventTypes.Member, sender) for sender in senders]
)
state = await self.state_storage.get_state_for_event(
state = await self._state_storage_controller.get_state_for_event(
last_event_id, state_filter
)

View file

@ -238,8 +238,8 @@ class SyncHandler:
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
self.storage = hs.get_storage()
self.state_storage = self.storage.state
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
# TODO: flush cache entries on subsequent sync request.
# Once we get the next /sync request (ie, one with the same access token
@ -512,7 +512,7 @@ class SyncHandler:
current_state_ids = frozenset(current_state_ids_map.values())
recents = await filter_events_for_client(
self.storage,
self._storage_controllers,
sync_config.user.to_string(),
recents,
always_include_ids=current_state_ids,
@ -580,7 +580,7 @@ class SyncHandler:
current_state_ids = frozenset(current_state_ids_map.values())
loaded_recents = await filter_events_for_client(
self.storage,
self._storage_controllers,
sync_config.user.to_string(),
loaded_recents,
always_include_ids=current_state_ids,
@ -630,7 +630,7 @@ class SyncHandler:
event: event of interest
state_filter: The state filter used to fetch state from the database.
"""
state_ids = await self.state_storage.get_state_ids_for_event(
state_ids = await self._state_storage_controller.get_state_ids_for_event(
event.event_id, state_filter=state_filter or StateFilter.all()
)
if event.is_state():
@ -710,7 +710,7 @@ class SyncHandler:
return None
last_event = last_events[-1]
state_ids = await self.state_storage.get_state_ids_for_event(
state_ids = await self._state_storage_controller.get_state_ids_for_event(
last_event.event_id,
state_filter=StateFilter.from_types(
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@ -889,13 +889,15 @@ class SyncHandler:
if full_state:
if batch:
current_state_ids = (
await self.state_storage.get_state_ids_for_event(
await self._state_storage_controller.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter
)
)
state_ids = await self.state_storage.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
state_ids = (
await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
)
)
else:
@ -915,7 +917,7 @@ class SyncHandler:
elif batch.limited:
if batch:
state_at_timeline_start = (
await self.state_storage.get_state_ids_for_event(
await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
)
)
@ -950,7 +952,7 @@ class SyncHandler:
if batch:
current_state_ids = (
await self.state_storage.get_state_ids_for_event(
await self._state_storage_controller.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter
)
)
@ -982,7 +984,7 @@ class SyncHandler:
# So we fish out all the member events corresponding to the
# timeline here, and then dedupe any redundant ones below.
state_ids = await self.state_storage.get_state_ids_for_event(
state_ids = await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id,
# we only want members!
state_filter=StateFilter.from_types(

View file

@ -221,7 +221,7 @@ class Notifier:
self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {}
self.hs = hs
self.storage = hs.get_storage()
self._storage_controllers = hs.get_storage_controllers()
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastores().main
self.pending_new_room_events: List[_PendingRoomEventEntry] = []
@ -623,7 +623,7 @@ class Notifier:
if name == "room":
new_events = await filter_events_for_client(
self.storage,
self._storage_controllers,
user.to_string(),
new_events,
is_peeking=is_peeking,

View file

@ -65,7 +65,7 @@ class HttpPusher(Pusher):
def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
super().__init__(hs, pusher_config)
self.storage = self.hs.get_storage()
self._storage_controllers = self.hs.get_storage_controllers()
self.app_display_name = pusher_config.app_display_name
self.device_display_name = pusher_config.device_display_name
self.pushkey_ts = pusher_config.ts
@ -343,7 +343,9 @@ class HttpPusher(Pusher):
}
return d
ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id)
ctx = await push_tools.get_context_for_event(
self._storage_controllers, event, self.user_id
)
d = {
"notification": {

View file

@ -114,10 +114,10 @@ class Mailer:
self.send_email_handler = hs.get_send_email_handler()
self.store = self.hs.get_datastores().main
self.state_storage = self.hs.get_storage().state
self._state_storage_controller = self.hs.get_storage_controllers().state
self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler()
self.storage = hs.get_storage()
self._storage_controllers = hs.get_storage_controllers()
self.app_name = app_name
self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects
@ -456,7 +456,7 @@ class Mailer:
}
the_events = await filter_events_for_client(
self.storage, user_id, results.events_before
self._storage_controllers, user_id, results.events_before
)
the_events.append(notif_event)
@ -494,7 +494,7 @@ class Mailer:
)
else:
# Attempt to check the historical state for the room.
historical_state = await self.state_storage.get_state_for_event(
historical_state = await self._state_storage_controller.get_state_for_event(
event.event_id, StateFilter.from_types((type_state_key,))
)
sender_state_event = historical_state.get(type_state_key)
@ -767,8 +767,10 @@ class Mailer:
member_event_ids.append(sender_state_event_id)
else:
# Attempt to check the historical state for the room.
historical_state = await self.state_storage.get_state_for_event(
event_id, StateFilter.from_types((type_state_key,))
historical_state = (
await self._state_storage_controller.get_state_for_event(
event_id, StateFilter.from_types((type_state_key,))
)
)
sender_state_event = historical_state.get(type_state_key)
if sender_state_event:

View file

@ -16,7 +16,7 @@ from typing import Dict
from synapse.api.constants import ReceiptTypes
from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
@ -52,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
async def get_context_for_event(
storage: Storage, ev: EventBase, user_id: str
storage: StorageControllers, ev: EventBase, user_id: str
) -> Dict[str, str]:
ctx = {}

View file

@ -69,7 +69,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
super().__init__(hs)
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self._storage_controllers = hs.get_storage_controllers()
self.clock = hs.get_clock()
self.federation_event_handler = hs.get_federation_event_handler()
@ -133,7 +133,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event.internal_metadata.outlier = event_payload["outlier"]
context = EventContext.deserialize(
self.storage, event_payload["context"]
self._storage_controllers, event_payload["context"]
)
event_and_contexts.append((event, context))

View file

@ -70,7 +70,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self._storage_controllers = hs.get_storage_controllers()
self.clock = hs.get_clock()
@staticmethod
@ -127,7 +127,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
event.internal_metadata.outlier = content["outlier"]
requester = Requester.deserialize(self.store, content["requester"])
context = EventContext.deserialize(self.storage, content["context"])
context = EventContext.deserialize(
self._storage_controllers, content["context"]
)
ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]

View file

@ -123,7 +123,8 @@ from synapse.server_notices.worker_server_notices_sender import (
WorkerServerNoticesSender,
)
from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import Databases, Storage
from synapse.storage import Databases
from synapse.storage.controllers import StorageControllers
from synapse.streams.events import EventSources
from synapse.types import DomainSpecificString, ISynapseReactor
from synapse.util import Clock
@ -729,8 +730,8 @@ class HomeServer(metaclass=abc.ABCMeta):
return PasswordPolicyHandler(self)
@cache_in_self
def get_storage(self) -> Storage:
return Storage(self, self.get_datastores())
def get_storage_controllers(self) -> StorageControllers:
return StorageControllers(self, self.get_datastores())
@cache_in_self
def get_replication_streamer(self) -> ReplicationStreamer:

View file

@ -127,10 +127,10 @@ class StateHandler:
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
self.state_storage = hs.get_storage().state
self._state_storage_controller = hs.get_storage_controllers().state
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
self._storage = hs.get_storage()
self._storage_controllers = hs.get_storage_controllers()
@overload
async def get_current_state(
@ -337,12 +337,14 @@ class StateHandler:
#
if not state_group_before_event:
state_group_before_event = await self.state_storage.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event_prev_group,
delta_ids=deltas_to_state_group_before_event,
current_state_ids=state_ids_before_event,
state_group_before_event = (
await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event_prev_group,
delta_ids=deltas_to_state_group_before_event,
current_state_ids=state_ids_before_event,
)
)
# Assign the new state group to the cached state entry.
@ -359,7 +361,7 @@ class StateHandler:
if not event.is_state():
return EventContext.with_state(
storage=self._storage,
storage=self._storage_controllers,
state_group_before_event=state_group_before_event,
state_group=state_group_before_event,
state_delta_due_to_event={},
@ -382,16 +384,18 @@ class StateHandler:
state_ids_after_event[key] = event.event_id
delta_ids = {key: event.event_id}
state_group_after_event = await self.state_storage.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event,
delta_ids=delta_ids,
current_state_ids=state_ids_after_event,
state_group_after_event = (
await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event,
delta_ids=delta_ids,
current_state_ids=state_ids_after_event,
)
)
return EventContext.with_state(
storage=self._storage,
storage=self._storage_controllers,
state_group=state_group_after_event,
state_group_before_event=state_group_before_event,
state_delta_due_to_event=delta_ids,
@ -416,7 +420,9 @@ class StateHandler:
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
state_groups = await self.state_storage.get_state_group_for_events(event_ids)
state_groups = await self._state_storage_controller.get_state_group_for_events(
event_ids
)
state_group_ids = state_groups.values()
@ -424,8 +430,13 @@ class StateHandler:
state_group_ids_set = set(state_group_ids)
if len(state_group_ids_set) == 1:
(state_group_id,) = state_group_ids_set
state = await self.state_storage.get_state_for_groups(state_group_ids_set)
prev_group, delta_ids = await self.state_storage.get_state_group_delta(
state = await self._state_storage_controller.get_state_for_groups(
state_group_ids_set
)
(
prev_group,
delta_ids,
) = await self._state_storage_controller.get_state_group_delta(
state_group_id
)
return _StateCacheEntry(
@ -439,7 +450,7 @@ class StateHandler:
room_version = await self.store.get_room_version_id(room_id)
state_to_resolve = await self.state_storage.get_state_for_groups(
state_to_resolve = await self._state_storage_controller.get_state_for_groups(
state_group_ids_set
)

View file

@ -18,41 +18,20 @@ The storage layer is split up into multiple parts to allow Synapse to run
against different configurations of databases (e.g. single or multiple
databases). The `DatabasePool` class represents connections to a single physical
database. The `databases` are classes that talk directly to a `DatabasePool`
instance and have associated schemas, background updates, etc. On top of those
there are classes that provide high level interfaces that combine calls to
multiple `databases`.
instance and have associated schemas, background updates, etc.
On top of the databases are the StorageControllers, located in the
`synapse.storage.controllers` module. These classes provide high level
interfaces that combine calls to multiple `databases`. They are bundled into the
`StorageControllers` singleton for ease of use, and exposed via
`HomeServer.get_storage_controllers()`.
There are also schemas that get applied to every database, regardless of the
data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`.
"""
from typing import TYPE_CHECKING
from synapse.storage.databases import Databases
from synapse.storage.databases.main import DataStore
from synapse.storage.persist_events import EventsPersistenceStorage
from synapse.storage.purge_events import PurgeEventsStorage
from synapse.storage.state import StateGroupStorage
if TYPE_CHECKING:
from synapse.server import HomeServer
__all__ = ["Databases", "DataStore"]
class Storage:
"""The high level interfaces for talking to various storage layers."""
def __init__(self, hs: "HomeServer", stores: Databases):
# We include the main data store here mainly so that we don't have to
# rewrite all the existing code to split it into high vs low level
# interfaces.
self.main = stores.main
self.purge_events = PurgeEventsStorage(hs, stores)
self.state = StateGroupStorage(hs, stores)
self.persistence = None
if stores.persist_events:
self.persistence = EventsPersistenceStorage(hs, stores)

View file

@ -0,0 +1,46 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from synapse.storage.controllers.persist_events import (
EventsPersistenceStorageController,
)
from synapse.storage.controllers.purge_events import PurgeEventsStorageController
from synapse.storage.controllers.state import StateGroupStorageController
from synapse.storage.databases import Databases
from synapse.storage.databases.main import DataStore
if TYPE_CHECKING:
from synapse.server import HomeServer
__all__ = ["Databases", "DataStore"]
class StorageControllers:
"""The high level interfaces for talking to various storage controller layers."""
def __init__(self, hs: "HomeServer", stores: Databases):
# We include the main data store here mainly so that we don't have to
# rewrite all the existing code to split it into high vs low level
# interfaces.
self.main = stores.main
self.purge_events = PurgeEventsStorageController(hs, stores)
self.state = StateGroupStorageController(hs, stores)
self.persistence = None
if stores.persist_events:
self.persistence = EventsPersistenceStorageController(hs, stores)

View file

@ -272,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
pass
class EventsPersistenceStorage:
class EventsPersistenceStorageController:
"""High level interface for handling persisting newly received events.
Takes care of batching up events by room, and calculating the necessary

View file

@ -24,7 +24,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class PurgeEventsStorage:
class PurgeEventsStorageController:
"""High level interface for purging rooms and event history."""
def __init__(self, hs: "HomeServer", stores: Databases):

View file

@ -0,0 +1,351 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
Awaitable,
Collection,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
)
from synapse.events import EventBase
from synapse.storage.state import StateFilter
from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
from synapse.types import MutableStateMap, StateMap
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases import Databases
logger = logging.getLogger(__name__)
class StateGroupStorageController:
"""High level interface to fetching state for event."""
def __init__(self, hs: "HomeServer", stores: "Databases"):
self._is_mine_id = hs.is_mine_id
self.stores = stores
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
def notify_event_un_partial_stated(self, event_id: str) -> None:
self._partial_state_events_tracker.notify_un_partial_stated(event_id)
async def get_state_group_delta(
self, state_group: int
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
"""Given a state group try to return a previous group and a delta between
the old and the new.
Args:
state_group: The state group used to retrieve state deltas.
Returns:
A tuple of the previous group and a state map of the event IDs which
make up the delta between the old and new state groups.
"""
state_group_delta = await self.stores.state.get_state_group_delta(state_group)
return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids(
self, _room_id: str, event_ids: Collection[str]
) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
Args:
_room_id: id of the room for these events
event_ids: ids of the events
Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id)
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
if not event_ids:
return {}
event_to_groups = await self.get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups)
return group_to_state
async def get_state_ids_for_group(
self, state_group: int, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""Get the event IDs of all the state in the given state group
Args:
state_group: A state group for which we want to get the state IDs.
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
Returns:
Resolves to a map of (type, state_key) -> event_id
"""
group_to_state = await self.get_state_for_groups((state_group,), state_filter)
return group_to_state[state_group]
async def get_state_groups(
self, room_id: str, event_ids: Collection[str]
) -> Dict[int, List[EventBase]]:
"""Get the state groups for the given list of event_ids
Args:
room_id: ID of the room for these events.
event_ids: The event IDs to retrieve state for.
Returns:
dict of state_group_id -> list of state events.
"""
if not event_ids:
return {}
group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
state_event_map = await self.stores.main.get_events(
[
ev_id
for group_ids in group_to_ids.values()
for ev_id in group_ids.values()
],
get_prev_content=False,
)
return {
group: [
state_event_map[v]
for v in event_id_map.values()
if v in state_event_map
]
for group, event_id_map in group_to_ids.items()
}
def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
) -> Awaitable[Dict[int, StateMap[str]]]:
"""Returns the state groups for a given set of groups, filtering on
types of state events.
Args:
groups: list of state group IDs to query
state_filter: The state filter used to fetch state
from the database.
Returns:
Dict of state group to state map.
"""
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
async def get_state_for_events(
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
Args:
event_ids: The events to fetch the state of.
state_filter: The state filter used to fetch state.
Returns:
A dict of (event_id) -> (type, state_key) -> [state_events]
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all()
)
state_event_map = await self.stores.main.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False,
)
event_to_state = {
event_id: {
k: state_event_map[v]
for k, v in group_to_state[group].items()
if v in state_event_map
}
for event_id, group in event_to_groups.items()
}
return {event: event_to_state[event] for event in event_ids}
async def get_state_ids_for_events(
self,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
Args:
event_ids: events whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A dict from event_id -> (type, state_key) -> event_id
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all()
)
event_to_state = {
event_id: group_to_state[group]
for event_id, group in event_to_groups.items()
}
return {event: event_to_state[event] for event in event_ids}
async def get_state_for_event(
self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[EventBase]:
"""
Get the state dict corresponding to a particular event
Args:
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A dict from (type, state_key) -> state_event
Raises:
RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown)
"""
state_map = await self.get_state_for_events(
[event_id], state_filter or StateFilter.all()
)
return state_map[event_id]
async def get_state_ids_for_event(
self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event
Args:
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A dict from (type, state_key) -> state_event_id
Raises:
RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown)
"""
state_map = await self.get_state_ids_for_events(
[event_id], state_filter or StateFilter.all()
)
return state_map[event_id]
def get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
Args:
groups: list of state groups for which we want to get the state.
state_filter: The state filter used to fetch state.
from the database.
Returns:
Dict of state group to state map.
"""
return self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all()
)
async def get_state_group_for_events(
self,
event_ids: Collection[str],
await_full_state: bool = True,
) -> Mapping[str, int]:
"""Returns mapping event_id -> state_group
Args:
event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete
state at these events.
"""
if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids)
return await self.stores.main._get_state_group_for_events(event_ids)
async def store_state_group(
self,
event_id: str,
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[StateMap[str]],
current_state_ids: StateMap[str],
) -> int:
"""Store a new set of state, returning a newly assigned state group.
Args:
event_id: The event ID for which the state was calculated.
room_id: ID of the room for which the state was calculated.
prev_group: A previous state group for the room, optional.
delta_ids: The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
current_state_ids: The state to store. Map of (type, state_key)
to event_id.
Returns:
The state group ID
"""
return await self.stores.state.store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)

View file

@ -15,7 +15,6 @@
import logging
from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
Collection,
Dict,
@ -32,15 +31,11 @@ import attr
from frozendict import frozendict
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
from synapse.types import MutableStateMap, StateKey, StateMap
if TYPE_CHECKING:
from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
from synapse.server import HomeServer
from synapse.storage.databases import Databases
logger = logging.getLogger(__name__)
@ -578,318 +573,3 @@ _ALL_NON_MEMBER_STATE_FILTER = StateFilter(
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
)
_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)
class StateGroupStorage:
"""High level interface to fetching state for event."""
def __init__(self, hs: "HomeServer", stores: "Databases"):
self._is_mine_id = hs.is_mine_id
self.stores = stores
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
def notify_event_un_partial_stated(self, event_id: str) -> None:
self._partial_state_events_tracker.notify_un_partial_stated(event_id)
async def get_state_group_delta(
self, state_group: int
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
"""Given a state group try to return a previous group and a delta between
the old and the new.
Args:
state_group: The state group used to retrieve state deltas.
Returns:
A tuple of the previous group and a state map of the event IDs which
make up the delta between the old and new state groups.
"""
state_group_delta = await self.stores.state.get_state_group_delta(state_group)
return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids(
self, _room_id: str, event_ids: Collection[str]
) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
Args:
_room_id: id of the room for these events
event_ids: ids of the events
Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id)
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
if not event_ids:
return {}
event_to_groups = await self.get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups)
return group_to_state
async def get_state_ids_for_group(
self, state_group: int, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""Get the event IDs of all the state in the given state group
Args:
state_group: A state group for which we want to get the state IDs.
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
Returns:
Resolves to a map of (type, state_key) -> event_id
"""
group_to_state = await self.get_state_for_groups((state_group,), state_filter)
return group_to_state[state_group]
async def get_state_groups(
self, room_id: str, event_ids: Collection[str]
) -> Dict[int, List[EventBase]]:
"""Get the state groups for the given list of event_ids
Args:
room_id: ID of the room for these events.
event_ids: The event IDs to retrieve state for.
Returns:
dict of state_group_id -> list of state events.
"""
if not event_ids:
return {}
group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
state_event_map = await self.stores.main.get_events(
[
ev_id
for group_ids in group_to_ids.values()
for ev_id in group_ids.values()
],
get_prev_content=False,
)
return {
group: [
state_event_map[v]
for v in event_id_map.values()
if v in state_event_map
]
for group, event_id_map in group_to_ids.items()
}
def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
) -> Awaitable[Dict[int, StateMap[str]]]:
"""Returns the state groups for a given set of groups, filtering on
types of state events.
Args:
groups: list of state group IDs to query
state_filter: The state filter used to fetch state
from the database.
Returns:
Dict of state group to state map.
"""
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
async def get_state_for_events(
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
Args:
event_ids: The events to fetch the state of.
state_filter: The state filter used to fetch state.
Returns:
A dict of (event_id) -> (type, state_key) -> [state_events]
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all()
)
state_event_map = await self.stores.main.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False,
)
event_to_state = {
event_id: {
k: state_event_map[v]
for k, v in group_to_state[group].items()
if v in state_event_map
}
for event_id, group in event_to_groups.items()
}
return {event: event_to_state[event] for event in event_ids}
async def get_state_ids_for_events(
self,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
Args:
event_ids: events whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A dict from event_id -> (type, state_key) -> event_id
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all()
)
event_to_state = {
event_id: group_to_state[group]
for event_id, group in event_to_groups.items()
}
return {event: event_to_state[event] for event in event_ids}
async def get_state_for_event(
self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[EventBase]:
"""
Get the state dict corresponding to a particular event
Args:
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A dict from (type, state_key) -> state_event
Raises:
RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown)
"""
state_map = await self.get_state_for_events(
[event_id], state_filter or StateFilter.all()
)
return state_map[event_id]
async def get_state_ids_for_event(
self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event
Args:
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A dict from (type, state_key) -> state_event_id
Raises:
RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown)
"""
state_map = await self.get_state_ids_for_events(
[event_id], state_filter or StateFilter.all()
)
return state_map[event_id]
def get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
Args:
groups: list of state groups for which we want to get the state.
state_filter: The state filter used to fetch state.
from the database.
Returns:
Dict of state group to state map.
"""
return self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all()
)
async def get_state_group_for_events(
self,
event_ids: Collection[str],
await_full_state: bool = True,
) -> Mapping[str, int]:
"""Returns mapping event_id -> state_group
Args:
event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete
state at these events.
"""
if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids)
return await self.stores.main._get_state_group_for_events(event_ids)
async def store_state_group(
self,
event_id: str,
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[StateMap[str]],
current_state_ids: StateMap[str],
) -> int:
"""Store a new set of state, returning a newly assigned state group.
Args:
event_id: The event ID for which the state was calculated.
room_id: ID of the room for which the state was calculated.
prev_group: A previous state group for the room, optional.
delta_ids: The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
current_state_ids: The state to store. Map of (type, state_key)
to event_id.
Returns:
The state group ID
"""
return await self.stores.state.store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)

View file

@ -20,7 +20,7 @@ from typing_extensions import Final
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase
from synapse.events.utils import prune_event
from synapse.storage import Storage
from synapse.storage.controllers import StorageControllers
from synapse.storage.state import StateFilter
from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
@ -47,7 +47,7 @@ _HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, ""
async def filter_events_for_client(
storage: Storage,
storage: StorageControllers,
user_id: str,
events: List[EventBase],
is_peeking: bool = False,
@ -268,7 +268,7 @@ async def filter_events_for_client(
async def filter_events_for_server(
storage: Storage,
storage: StorageControllers,
server_name: str,
events: List[EventBase],
redact: bool = True,
@ -360,7 +360,7 @@ async def filter_events_for_server(
async def _event_to_history_vis(
storage: Storage, events: Collection[EventBase]
storage: StorageControllers, events: Collection[EventBase]
) -> Dict[str, str]:
"""Get the history visibility at each of the given events
@ -407,7 +407,7 @@ async def _event_to_history_vis(
async def _event_to_memberships(
storage: Storage, events: Collection[EventBase], server_name: str
storage: StorageControllers, events: Collection[EventBase], server_name: str
) -> Dict[str, StateMap[EventBase]]:
"""Get the remote membership list at each of the given events

View file

@ -29,7 +29,7 @@ class TestEventContext(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self._storage_controllers = hs.get_storage_controllers()
self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass")
@ -87,7 +87,7 @@ class TestEventContext(unittest.HomeserverTestCase):
def _check_serialize_deserialize(self, event, context):
serialized = self.get_success(context.serialize(event, self.store))
d_context = EventContext.deserialize(self.storage, serialized)
d_context = EventContext.deserialize(self._storage_controllers, serialized)
self.assertEqual(context.state_group, d_context.state_group)
self.assertEqual(context.rejected, d_context.rejected)

View file

@ -50,7 +50,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main
self.state_storage = hs.get_storage().state
self.state_storage_controller = hs.get_storage_controllers().state
self._event_auth_handler = hs.get_event_auth_handler()
return hs
@ -338,7 +338,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# mapping from (type, state_key) -> state_event_id
assert most_recent_prev_event_id is not None
prev_state_map = self.get_success(
self.state_storage.get_state_ids_for_event(most_recent_prev_event_id)
self.state_storage_controller.get_state_ids_for_event(
most_recent_prev_event_id
)
)
# List of state event ID's
prev_state_ids = list(prev_state_map.values())

View file

@ -70,7 +70,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
) -> None:
OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
main_store = self.hs.get_datastores().main
state_storage = self.hs.get_storage().state
state_storage_controller = self.hs.get_storage_controllers().state
# create the room
user_id = self.register_user("kermit", "test")
@ -146,10 +146,11 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
if prev_exists_as_outlier:
prev_event.internal_metadata.outlier = True
persistence = self.hs.get_storage().persistence
persistence = self.hs.get_storage_controllers().persistence
self.get_success(
persistence.persist_event(
prev_event, EventContext.for_outlier(self.hs.get_storage())
prev_event,
EventContext.for_outlier(self.hs.get_storage_controllers()),
)
)
else:
@ -216,7 +217,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
# check that the state at that event is as expected
state = self.get_success(
state_storage.get_state_ids_for_event(pulled_event.event_id)
state_storage_controller.get_state_ids_for_event(pulled_event.event_id)
)
expected_state = {
(e.type, e.state_key): e.event_id for e in state_at_prev_event

View file

@ -37,7 +37,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.handler = self.hs.get_event_creation_handler()
self.persist_event_storage = self.hs.get_storage().persistence
self._persist_event_storage_controller = (
self.hs.get_storage_controllers().persistence
)
self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar")
@ -65,7 +67,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
self.persist_event_storage.persist_event(memberEvent, memberEventContext)
self._persist_event_storage_controller.persist_event(
memberEvent, memberEventContext
)
)
return memberEvent, memberEventContext
@ -129,7 +133,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event3.event_id)
ret_event3, event_pos3, _ = self.get_success(
self.persist_event_storage.persist_event(event3, context)
self._persist_event_storage_controller.persist_event(event3, context)
)
# Assert that the returned values match those from the initial event
@ -143,7 +147,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event3.event_id)
events, _ = self.get_success(
self.persist_event_storage.persist_events([(event3, context)])
self._persist_event_storage_controller.persist_events([(event3, context)])
)
ret_event4 = events[0]
@ -166,7 +170,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event2.event_id)
events, _ = self.get_success(
self.persist_event_storage.persist_events(
self._persist_event_storage_controller.persist_events(
[(event1, context1), (event2, context2)]
)
)

View file

@ -954,7 +954,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
self.get_success(
self.hs.get_storage().persistence.persist_event(event, context)
self.hs.get_storage_controllers().persistence.persist_event(event, context)
)
def test_local_user_leaving_room_remains_in_user_directory(self) -> None:

View file

@ -32,7 +32,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
self.master_store = hs.get_datastores().main
self.slaved_store = self.worker_hs.get_datastores().main
self.storage = hs.get_storage()
self._storage_controllers = hs.get_storage_controllers()
def replicate(self):
"""Tell the master side of replication that something has happened, and then

View file

@ -262,7 +262,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)
msg, msgctx = self.build_event()
self.get_success(
self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)])
self._storage_controllers.persistence.persist_events(
[(j2, j2ctx), (msg, msgctx)]
)
)
self.replicate()
@ -323,12 +325,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if backfill:
self.get_success(
self.storage.persistence.persist_events(
self._storage_controllers.persistence.persist_events(
[(event, context)], backfilled=True
)
)
else:
self.get_success(self.storage.persistence.persist_event(event, context))
self.get_success(
self._storage_controllers.persistence.persist_event(event, context)
)
return event

View file

@ -31,7 +31,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
def prepare(self, reactor, clock, homeserver):
super().prepare(reactor, clock, homeserver)
self.room_creator = homeserver.get_room_creation_handler()
self.persist_event_storage = self.hs.get_storage().persistence
self.persist_event_storage_controller = (
self.hs.get_storage_controllers().persistence
)
# Create a test user
self.ourUser = UserID.from_string(OUR_USER_ID)
@ -61,7 +63,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
)
)
self.get_success(
self.persist_event_storage.persist_event(memberEvent, memberEventContext)
self.persist_event_storage_controller.persist_event(
memberEvent, memberEventContext
)
)
# Join the second user to the second room
@ -76,7 +80,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
)
)
self.get_success(
self.persist_event_storage.persist_event(memberEvent, memberEventContext)
self.persist_event_storage_controller.persist_event(
memberEvent, memberEventContext
)
)
def test_return_empty_with_no_data(self):

View file

@ -2579,7 +2579,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass")
event_builder_factory = self.hs.get_event_builder_factory()
event_creation_handler = self.hs.get_event_creation_handler()
storage = self.hs.get_storage()
storage_controllers = self.hs.get_storage_controllers()
# Create two rooms, one with a local user only and one with both a local
# and remote user.
@ -2604,7 +2604,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
event_creation_handler.create_new_client_event(builder)
)
self.get_success(storage.persistence.persist_event(event, context))
self.get_success(storage_controllers.persistence.persist_event(event, context))
# Now get rooms
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"

View file

@ -130,7 +130,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
We do this by setting a very long time between purge jobs.
"""
store = self.hs.get_datastores().main
storage = self.hs.get_storage()
storage_controllers = self.hs.get_storage_controllers()
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
# Send a first event, which should be filtered out at the end of the test.
@ -155,7 +155,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(2, len(events), "events retrieved from database")
filtered_events = self.get_success(
filter_events_for_client(storage, self.user_id, events)
filter_events_for_client(storage_controllers, self.user_id, events)
)
# We should only get one event back.

View file

@ -88,7 +88,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.clock = clock
self.storage = hs.get_storage()
self._storage_controllers = hs.get_storage_controllers()
self.virtual_user_id, _ = self.register_appservice_user(
"as_user_potato", self.appservice.token
@ -168,7 +168,9 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
# Fetch the state_groups
state_group_map = self.get_success(
self.storage.state.get_state_groups_ids(room_id, historical_event_ids)
self._storage_controllers.state.get_state_groups_ids(
room_id, historical_event_ids
)
)
# We expect all of the historical events to be using the same state_group

View file

@ -393,7 +393,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
# We need to persist the events to the events and state_events
# tables.
persist_events_store._store_event_txn(
txn, [(e, EventContext(self.hs.get_storage())) for e in events]
txn,
[(e, EventContext(self.hs.get_storage_controllers())) for e in events],
)
# Actually call the function that calculates the auth chain stuff.

View file

@ -31,7 +31,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.state = self.hs.get_state_handler()
self.persistence = self.hs.get_storage().persistence
self._persistence = self.hs.get_storage_controllers().persistence
self.store = self.hs.get_datastores().main
self.register_user("user", "pass")
@ -71,7 +71,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
context = self.get_success(
self.state.compute_event_context(event, state_ids_before_event=state)
)
self.get_success(self.persistence.persist_event(event, context))
self.get_success(self._persistence.persist_event(event, context))
def assert_extremities(self, expected_extremities):
"""Assert the current extremities for the room"""
@ -148,7 +148,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
)
)
self.get_success(self.persistence.persist_event(remote_event_2, context))
self.get_success(self._persistence.persist_event(remote_event_2, context))
# Check that we haven't dropped the old extremity.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@ -353,7 +353,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.state = self.hs.get_state_handler()
self.persistence = self.hs.get_storage().persistence
self._persistence = self.hs.get_storage_controllers().persistence
self.store = self.hs.get_datastores().main
def test_remote_user_rooms_cache_invalidated(self):
@ -390,7 +390,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
)
context = self.get_success(self.state.compute_event_context(remote_event_1))
self.get_success(self.persistence.persist_event(remote_event_1, context))
self.get_success(self._persistence.persist_event(remote_event_1, context))
# Call `get_rooms_for_user` to add the remote user to the cache
rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
@ -437,7 +437,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
)
context = self.get_success(self.state.compute_event_context(remote_event_1))
self.get_success(self.persistence.persist_event(remote_event_1, context))
self.get_success(self._persistence.persist_event(remote_event_1, context))
# Call `get_users_in_room` to add the remote user to the cache
users = self.get_success(self.store.get_users_in_room(room_id))

View file

@ -31,7 +31,7 @@ class PurgeTests(HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id)
self.store = hs.get_datastores().main
self.storage = self.hs.get_storage()
self._storage_controllers = self.hs.get_storage_controllers()
def test_purge_history(self):
"""
@ -51,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
# Purge everything before this topological token
self.get_success(
self.storage.purge_events.purge_history(self.room_id, token_str, True)
self._storage_controllers.purge_events.purge_history(
self.room_id, token_str, True
)
)
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
@ -79,7 +81,9 @@ class PurgeTests(HomeserverTestCase):
# Purge everything before this topological token
f = self.get_failure(
self.storage.purge_events.purge_history(self.room_id, event, True),
self._storage_controllers.purge_events.purge_history(
self.room_id, event, True
),
SynapseError,
)
self.assertIn("greater than forward", f.value.args[0])
@ -105,7 +109,9 @@ class PurgeTests(HomeserverTestCase):
self.assertIsNotNone(create_event)
# Purge everything before this topological token
self.get_success(self.storage.purge_events.purge_room(self.room_id))
self.get_success(
self._storage_controllers.purge_events.purge_room(self.room_id)
)
# The events aren't found.
self.store._invalidate_get_event_cache(create_event.event_id)

View file

@ -31,7 +31,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self._storage = hs.get_storage_controllers()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@ -71,7 +71,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
self.get_success(self.storage.persistence.persist_event(event, context))
self.get_success(self._storage.persistence.persist_event(event, context))
return event
@ -93,7 +93,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
self.get_success(self.storage.persistence.persist_event(event, context))
self.get_success(self._storage.persistence.persist_event(event, context))
return event
@ -114,7 +114,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
self.get_success(self.storage.persistence.persist_event(event, context))
self.get_success(self._storage.persistence.persist_event(event, context))
return event
@ -268,7 +268,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(self.storage.persistence.persist_event(event_1, context_1))
self.get_success(self._storage.persistence.persist_event(event_1, context_1))
event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
@ -287,7 +287,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
)
self.get_success(self.storage.persistence.persist_event(event_2, context_2))
self.get_success(self._storage.persistence.persist_event(event_2, context_2))
# fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1))
@ -411,7 +411,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
self.get_success(
self.storage.persistence.persist_event(redaction_event, context)
self._storage.persistence.persist_event(redaction_event, context)
)
# Now lets jump to the future where we have censored the redaction event

View file

@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self._storage = hs.get_storage_controllers()
self.event_factory = hs.get_event_factory()
self.room = RoomID.from_string("!abcde:test")
@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
def inject_room_event(self, **kwargs):
self.get_success(
self.storage.persistence.persist_event(
self._storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
)

View file

@ -99,7 +99,9 @@ class EventSearchInsertionTest(HomeserverTestCase):
prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
prev_event = self.get_success(store.get_event(prev_event_ids[0]))
prev_state_map = self.get_success(
self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0])
self.hs.get_storage_controllers().state.get_state_ids_for_event(
prev_event_ids[0]
)
)
event_dict = {

View file

@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class StateStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self.storage = hs.get_storage_controllers()
self.state_datastore = self.storage.state.stores.state
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()

View file

@ -179,12 +179,12 @@ class Graph:
class StateTestCase(unittest.TestCase):
def setUp(self):
self.dummy_store = _DummyStore()
storage = Mock(main=self.dummy_store, state=self.dummy_store)
storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
hs = Mock(
spec_set=[
"config",
"get_datastores",
"get_storage",
"get_storage_controllers",
"get_auth",
"get_state_handler",
"get_clock",
@ -199,7 +199,7 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
hs.get_storage.return_value = storage
hs.get_storage_controllers.return_value = storage_controllers
self.state = StateHandler(hs)
self.event_id = 0

View file

@ -70,7 +70,7 @@ async def inject_event(
"""
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
persistence = hs.get_storage().persistence
persistence = hs.get_storage_controllers().persistence
assert persistence is not None
await persistence.persist_event(event, context)

View file

@ -34,7 +34,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
super(FilterEventsForServerTestCase, self).setUp()
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
self.storage = self.hs.get_storage()
self._storage_controllers = self.hs.get_storage_controllers()
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@ -60,7 +60,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
events_to_filter.append(evt)
filtered = self.get_success(
filter_events_for_server(self.storage, "test_server", events_to_filter)
filter_events_for_server(
self._storage_controllers, "test_server", events_to_filter
)
)
# the result should be 5 redacted events, and 5 unredacted events.
@ -80,7 +82,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
outlier = self._inject_outlier()
self.assertEqual(
self.get_success(
filter_events_for_server(self.storage, "remote_hs", [outlier])
filter_events_for_server(
self._storage_controllers, "remote_hs", [outlier]
)
),
[outlier],
)
@ -89,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
evt = self._inject_message("@unerased:local_hs")
filtered = self.get_success(
filter_events_for_server(self.storage, "remote_hs", [outlier, evt])
filter_events_for_server(
self._storage_controllers, "remote_hs", [outlier, evt]
)
)
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
self.assertEqual(filtered[0], outlier)
@ -99,7 +105,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... but other servers should only be able to see the outlier (the other should
# be redacted)
filtered = self.get_success(
filter_events_for_server(self.storage, "other_server", [outlier, evt])
filter_events_for_server(
self._storage_controllers, "other_server", [outlier, evt]
)
)
self.assertEqual(filtered[0], outlier)
self.assertEqual(filtered[1].event_id, evt.event_id)
@ -132,7 +140,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... and the filtering happens.
filtered = self.get_success(
filter_events_for_server(self.storage, "test_server", events_to_filter)
filter_events_for_server(
self._storage_controllers, "test_server", events_to_filter
)
)
for i in range(0, len(events_to_filter)):
@ -168,7 +178,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
self.get_success(self.storage.persistence.persist_event(event, context))
self.get_success(
self._storage_controllers.persistence.persist_event(event, context)
)
return event
def _inject_room_member(
@ -194,7 +206,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
self.get_success(self.storage.persistence.persist_event(event, context))
self.get_success(
self._storage_controllers.persistence.persist_event(event, context)
)
return event
def _inject_message(
@ -216,7 +230,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
self.get_success(self.storage.persistence.persist_event(event, context))
self.get_success(
self._storage_controllers.persistence.persist_event(event, context)
)
return event
def _inject_outlier(self) -> EventBase:
@ -234,8 +250,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
event.internal_metadata.outlier = True
self.get_success(
self.storage.persistence.persist_event(
event, EventContext.for_outlier(self.storage)
self._storage_controllers.persistence.persist_event(
event, EventContext.for_outlier(self._storage_controllers)
)
)
return event
@ -293,7 +309,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(
self.get_success(
filter_events_for_client(
self.hs.get_storage(), "@user:test", [invite_event, reject_event]
self.hs.get_storage_controllers(),
"@user:test",
[invite_event, reject_event],
)
),
[invite_event, reject_event],
@ -303,7 +321,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(
self.get_success(
filter_events_for_client(
self.hs.get_storage(), "@other:test", [invite_event, reject_event]
self.hs.get_storage_controllers(),
"@other:test",
[invite_event, reject_event],
)
),
[],

View file

@ -264,7 +264,7 @@ class MockClock:
async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room"""
persistence_store = hs.get_storage().persistence
persistence_store = hs.get_storage_controllers().persistence
store = hs.get_datastores().main
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()