Rename storage classes (#12913)

This commit is contained in:
Erik Johnston 2022-05-31 13:17:50 +01:00 committed by GitHub
parent e541bb9eed
commit 1e453053cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
53 changed files with 708 additions and 551 deletions

1
changelog.d/12913.misc Normal file
View file

@ -0,0 +1 @@
Rename storage classes.

View file

@ -22,7 +22,7 @@ from synapse.events import EventBase
from synapse.types import JsonDict, StateMap from synapse.types import JsonDict, StateMap
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.storage import Storage from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
@ -84,7 +84,7 @@ class EventContext:
incomplete state. incomplete state.
""" """
_storage: "Storage" _storage: "StorageControllers"
rejected: Union[Literal[False], str] = False rejected: Union[Literal[False], str] = False
_state_group: Optional[int] = None _state_group: Optional[int] = None
state_group_before_event: Optional[int] = None state_group_before_event: Optional[int] = None
@ -97,7 +97,7 @@ class EventContext:
@staticmethod @staticmethod
def with_state( def with_state(
storage: "Storage", storage: "StorageControllers",
state_group: Optional[int], state_group: Optional[int],
state_group_before_event: Optional[int], state_group_before_event: Optional[int],
state_delta_due_to_event: Optional[StateMap[str]], state_delta_due_to_event: Optional[StateMap[str]],
@ -117,7 +117,7 @@ class EventContext:
@staticmethod @staticmethod
def for_outlier( def for_outlier(
storage: "Storage", storage: "StorageControllers",
) -> "EventContext": ) -> "EventContext":
"""Return an EventContext instance suitable for persisting an outlier event""" """Return an EventContext instance suitable for persisting an outlier event"""
return EventContext(storage=storage) return EventContext(storage=storage)
@ -147,7 +147,7 @@ class EventContext:
} }
@staticmethod @staticmethod
def deserialize(storage: "Storage", input: JsonDict) -> "EventContext": def deserialize(storage: "StorageControllers", input: JsonDict) -> "EventContext":
"""Converts a dict that was produced by `serialize` back into a """Converts a dict that was produced by `serialize` back into a
EventContext. EventContext.

View file

@ -109,7 +109,6 @@ class FederationServer(FederationBase):
super().__init__(hs) super().__init__(hs)
self.handler = hs.get_federation_handler() self.handler = hs.get_federation_handler()
self.storage = hs.get_storage()
self._spam_checker = hs.get_spam_checker() self._spam_checker = hs.get_spam_checker()
self._federation_event_handler = hs.get_federation_event_handler() self._federation_event_handler = hs.get_federation_event_handler()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()

View file

@ -30,8 +30,8 @@ logger = logging.getLogger(__name__)
class AdminHandler: class AdminHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
self.state_storage = self.storage.state self._state_storage_controller = self._storage_controllers.state
async def get_whois(self, user: UserID) -> JsonDict: async def get_whois(self, user: UserID) -> JsonDict:
connections = [] connections = []
@ -197,7 +197,9 @@ class AdminHandler:
from_key = events[-1].internal_metadata.after from_key = events[-1].internal_metadata.after
events = await filter_events_for_client(self.storage, user_id, events) events = await filter_events_for_client(
self._storage_controllers, user_id, events
)
writer.write_events(room_id, events) writer.write_events(room_id, events)
@ -233,7 +235,9 @@ class AdminHandler:
for event_id in extremities: for event_id in extremities:
if not event_to_unseen_prevs[event_id]: if not event_to_unseen_prevs[event_id]:
continue continue
state = await self.state_storage.get_state_for_event(event_id) state = await self._state_storage_controller.get_state_for_event(
event_id
)
writer.write_state(room_id, event_id, state) writer.write_state(room_id, event_id, state)
return writer.finished() return writer.finished()

View file

@ -71,7 +71,7 @@ class DeviceWorkerHandler:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.state_storage = hs.get_storage().state self._state_storage = hs.get_storage_controllers().state
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
@ -204,7 +204,7 @@ class DeviceWorkerHandler:
continue continue
# mapping from event_id -> state_dict # mapping from event_id -> state_dict
prev_state_ids = await self.state_storage.get_state_ids_for_events( prev_state_ids = await self._state_storage.get_state_ids_for_events(
event_ids event_ids
) )

View file

@ -139,7 +139,7 @@ class EventStreamHandler:
class EventHandler: class EventHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
async def get_event( async def get_event(
self, self,
@ -177,7 +177,7 @@ class EventHandler:
is_peeking = user.to_string() not in users is_peeking = user.to_string() not in users
filtered = await filter_events_for_client( filtered = await filter_events_for_client(
self.storage, user.to_string(), [event], is_peeking=is_peeking self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
) )
if not filtered: if not filtered:

View file

@ -125,8 +125,8 @@ class FederationHandler:
self.hs = hs self.hs = hs
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
self.state_storage = self.storage.state self._state_storage_controller = self._storage_controllers.state
self.federation_client = hs.get_federation_client() self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
@ -324,7 +324,7 @@ class FederationHandler:
# We set `check_history_visibility_only` as we might otherwise get false # We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased. # positives from users having been erased.
filtered_extremities = await filter_events_for_server( filtered_extremities = await filter_events_for_server(
self.storage, self._storage_controllers,
self.server_name, self.server_name,
events_to_check, events_to_check,
redact=False, redact=False,
@ -660,7 +660,7 @@ class FederationHandler:
# in the invitee's sync stream. It is stripped out for all other local users. # in the invitee's sync stream. It is stripped out for all other local users.
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
context = EventContext.for_outlier(self.storage) context = EventContext.for_outlier(self._storage_controllers)
stream_id = await self._federation_event_handler.persist_events_and_notify( stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)] event.room_id, [(event, context)]
) )
@ -849,7 +849,7 @@ class FederationHandler:
) )
) )
context = EventContext.for_outlier(self.storage) context = EventContext.for_outlier(self._storage_controllers)
await self._federation_event_handler.persist_events_and_notify( await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)] event.room_id, [(event, context)]
) )
@ -878,7 +878,7 @@ class FederationHandler:
await self.federation_client.send_leave(host_list, event) await self.federation_client.send_leave(host_list, event)
context = EventContext.for_outlier(self.storage) context = EventContext.for_outlier(self._storage_controllers)
stream_id = await self._federation_event_handler.persist_events_and_notify( stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)] event.room_id, [(event, context)]
) )
@ -1027,7 +1027,7 @@ class FederationHandler:
if event.internal_metadata.outlier: if event.internal_metadata.outlier:
raise NotFoundError("State not known at event %s" % (event_id,)) raise NotFoundError("State not known at event %s" % (event_id,))
state_groups = await self.state_storage.get_state_groups_ids( state_groups = await self._state_storage_controller.get_state_groups_ids(
room_id, [event_id] room_id, [event_id]
) )
@ -1078,7 +1078,9 @@ class FederationHandler:
], ],
) )
events = await filter_events_for_server(self.storage, origin, events) events = await filter_events_for_server(
self._storage_controllers, origin, events
)
return events return events
@ -1109,7 +1111,9 @@ class FederationHandler:
if not in_room: if not in_room:
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
events = await filter_events_for_server(self.storage, origin, [event]) events = await filter_events_for_server(
self._storage_controllers, origin, [event]
)
event = events[0] event = events[0]
return event return event
else: else:
@ -1138,7 +1142,7 @@ class FederationHandler:
) )
missing_events = await filter_events_for_server( missing_events = await filter_events_for_server(
self.storage, origin, missing_events self._storage_controllers, origin, missing_events
) )
return missing_events return missing_events
@ -1480,9 +1484,11 @@ class FederationHandler:
# clear the lazy-loading flag. # clear the lazy-loading flag.
logger.info("Updating current state for %s", room_id) logger.info("Updating current state for %s", room_id)
assert ( assert (
self.storage.persistence is not None self._storage_controllers.persistence is not None
), "TODO(faster_joins): support for workers" ), "TODO(faster_joins): support for workers"
await self.storage.persistence.update_current_state(room_id) await self._storage_controllers.persistence.update_current_state(
room_id
)
logger.info("Clearing partial-state flag for %s", room_id) logger.info("Clearing partial-state flag for %s", room_id)
success = await self.store.clear_partial_state_room(room_id) success = await self.store.clear_partial_state_room(room_id)

View file

@ -98,8 +98,8 @@ class FederationEventHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main self._store = hs.get_datastores().main
self._storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
self._state_storage = self._storage.state self._state_storage_controller = self._storage_controllers.state
self._state_handler = hs.get_state_handler() self._state_handler = hs.get_state_handler()
self._event_creation_handler = hs.get_event_creation_handler() self._event_creation_handler = hs.get_event_creation_handler()
@ -535,7 +535,9 @@ class FederationEventHandler:
) )
return return
await self._store.update_state_for_partial_state_event(event, context) await self._store.update_state_for_partial_state_event(event, context)
self._state_storage.notify_event_un_partial_stated(event.event_id) self._state_storage_controller.notify_event_un_partial_stated(
event.event_id
)
async def backfill( async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str] self, dest: str, room_id: str, limit: int, extremities: Collection[str]
@ -835,7 +837,9 @@ class FederationEventHandler:
try: try:
# Get the state of the events we know about # Get the state of the events we know about
ours = await self._state_storage.get_state_groups_ids(room_id, seen) ours = await self._state_storage_controller.get_state_groups_ids(
room_id, seen
)
# state_maps is a list of mappings from (type, state_key) to event_id # state_maps is a list of mappings from (type, state_key) to event_id
state_maps: List[StateMap[str]] = list(ours.values()) state_maps: List[StateMap[str]] = list(ours.values())
@ -1436,7 +1440,7 @@ class FederationEventHandler:
# we're not bothering about room state, so flag the event as an outlier. # we're not bothering about room state, so flag the event as an outlier.
event.internal_metadata.outlier = True event.internal_metadata.outlier = True
context = EventContext.for_outlier(self._storage) context = EventContext.for_outlier(self._storage_controllers)
try: try:
validate_event_for_room_version(room_version_obj, event) validate_event_for_room_version(room_version_obj, event)
check_auth_rules_for_event(room_version_obj, event, auth) check_auth_rules_for_event(room_version_obj, event, auth)
@ -1613,7 +1617,7 @@ class FederationEventHandler:
# given state at the event. This should correctly handle cases # given state at the event. This should correctly handle cases
# like bans, especially with state res v2. # like bans, especially with state res v2.
state_sets_d = await self._state_storage.get_state_groups_ids( state_sets_d = await self._state_storage_controller.get_state_groups_ids(
event.room_id, extrem_ids event.room_id, extrem_ids
) )
state_sets: List[StateMap[str]] = list(state_sets_d.values()) state_sets: List[StateMap[str]] = list(state_sets_d.values())
@ -1885,7 +1889,7 @@ class FederationEventHandler:
# create a new state group as a delta from the existing one. # create a new state group as a delta from the existing one.
prev_group = context.state_group prev_group = context.state_group
state_group = await self._state_storage.store_state_group( state_group = await self._state_storage_controller.store_state_group(
event.event_id, event.event_id,
event.room_id, event.room_id,
prev_group=prev_group, prev_group=prev_group,
@ -1894,7 +1898,7 @@ class FederationEventHandler:
) )
return EventContext.with_state( return EventContext.with_state(
storage=self._storage, storage=self._storage_controllers,
state_group=state_group, state_group=state_group,
state_group_before_event=context.state_group_before_event, state_group_before_event=context.state_group_before_event,
state_delta_due_to_event=state_updates, state_delta_due_to_event=state_updates,
@ -1984,11 +1988,14 @@ class FederationEventHandler:
) )
return result["max_stream_id"] return result["max_stream_id"]
else: else:
assert self._storage.persistence assert self._storage_controllers.persistence
# Note that this returns the events that were persisted, which may not be # Note that this returns the events that were persisted, which may not be
# the same as were passed in if some were deduplicated due to transaction IDs. # the same as were passed in if some were deduplicated due to transaction IDs.
events, max_stream_token = await self._storage.persistence.persist_events( (
events,
max_stream_token,
) = await self._storage_controllers.persistence.persist_events(
event_and_contexts, backfilled=backfilled event_and_contexts, backfilled=backfilled
) )

View file

@ -67,8 +67,8 @@ class InitialSyncHandler:
] ]
] = ResponseCache(hs.get_clock(), "initial_sync_cache") ] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
self.state_storage = self.storage.state self._state_storage_controller = self._storage_controllers.state
async def snapshot_all_rooms( async def snapshot_all_rooms(
self, self,
@ -198,7 +198,8 @@ class InitialSyncHandler:
event.stream_ordering, event.stream_ordering,
) )
deferred_room_state = run_in_background( deferred_room_state = run_in_background(
self.state_storage.get_state_for_events, [event.event_id] self._state_storage_controller.get_state_for_events,
[event.event_id],
).addCallback( ).addCallback(
lambda states: cast(StateMap[EventBase], states[event.event_id]) lambda states: cast(StateMap[EventBase], states[event.event_id])
) )
@ -218,7 +219,7 @@ class InitialSyncHandler:
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
messages = await filter_events_for_client( messages = await filter_events_for_client(
self.storage, user_id, messages self._storage_controllers, user_id, messages
) )
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
@ -355,7 +356,9 @@ class InitialSyncHandler:
member_event_id: str, member_event_id: str,
is_peeking: bool, is_peeking: bool,
) -> JsonDict: ) -> JsonDict:
room_state = await self.state_storage.get_state_for_event(member_event_id) room_state = await self._state_storage_controller.get_state_for_event(
member_event_id
)
limit = pagin_config.limit if pagin_config else None limit = pagin_config.limit if pagin_config else None
if limit is None: if limit is None:
@ -369,7 +372,7 @@ class InitialSyncHandler:
) )
messages = await filter_events_for_client( messages = await filter_events_for_client(
self.storage, user_id, messages, is_peeking=is_peeking self._storage_controllers, user_id, messages, is_peeking=is_peeking
) )
start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token) start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
@ -474,7 +477,7 @@ class InitialSyncHandler:
) )
messages = await filter_events_for_client( messages = await filter_events_for_client(
self.storage, user_id, messages, is_peeking=is_peeking self._storage_controllers, user_id, messages, is_peeking=is_peeking
) )
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)

View file

@ -84,8 +84,8 @@ class MessageHandler:
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
self.state_storage = self.storage.state self._state_storage_controller = self._storage_controllers.state
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages
@ -132,7 +132,7 @@ class MessageHandler:
assert ( assert (
membership_event_id is not None membership_event_id is not None
), "check_user_in_room_or_world_readable returned invalid data" ), "check_user_in_room_or_world_readable returned invalid data"
room_state = await self.state_storage.get_state_for_events( room_state = await self._state_storage_controller.get_state_for_events(
[membership_event_id], StateFilter.from_types([key]) [membership_event_id], StateFilter.from_types([key])
) )
data = room_state[membership_event_id].get(key) data = room_state[membership_event_id].get(key)
@ -193,7 +193,7 @@ class MessageHandler:
# check whether the user is in the room at that time to determine # check whether the user is in the room at that time to determine
# whether they should be treated as peeking. # whether they should be treated as peeking.
state_map = await self.state_storage.get_state_for_event( state_map = await self._state_storage_controller.get_state_for_event(
last_event.event_id, last_event.event_id,
StateFilter.from_types([(EventTypes.Member, user_id)]), StateFilter.from_types([(EventTypes.Member, user_id)]),
) )
@ -206,7 +206,7 @@ class MessageHandler:
is_peeking = not joined is_peeking = not joined
visible_events = await filter_events_for_client( visible_events = await filter_events_for_client(
self.storage, self._storage_controllers,
user_id, user_id,
[last_event], [last_event],
filter_send_to_client=False, filter_send_to_client=False,
@ -214,8 +214,10 @@ class MessageHandler:
) )
if visible_events: if visible_events:
room_state_events = await self.state_storage.get_state_for_events( room_state_events = (
[last_event.event_id], state_filter=state_filter await self._state_storage_controller.get_state_for_events(
[last_event.event_id], state_filter=state_filter
)
) )
room_state: Mapping[Any, EventBase] = room_state_events[ room_state: Mapping[Any, EventBase] = room_state_events[
last_event.event_id last_event.event_id
@ -244,8 +246,10 @@ class MessageHandler:
assert ( assert (
membership_event_id is not None membership_event_id is not None
), "check_user_in_room_or_world_readable returned invalid data" ), "check_user_in_room_or_world_readable returned invalid data"
room_state_events = await self.state_storage.get_state_for_events( room_state_events = (
[membership_event_id], state_filter=state_filter await self._state_storage_controller.get_state_for_events(
[membership_event_id], state_filter=state_filter
)
) )
room_state = room_state_events[membership_event_id] room_state = room_state_events[membership_event_id]
@ -402,7 +406,7 @@ class EventCreationHandler:
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._event_auth_handler = hs.get_event_auth_handler() self._event_auth_handler = hs.get_event_auth_handler()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.validator = EventValidator() self.validator = EventValidator()
@ -1032,7 +1036,7 @@ class EventCreationHandler:
# after it is created # after it is created
if builder.internal_metadata.outlier: if builder.internal_metadata.outlier:
event.internal_metadata.outlier = True event.internal_metadata.outlier = True
context = EventContext.for_outlier(self.storage) context = EventContext.for_outlier(self._storage_controllers)
elif ( elif (
event.type == EventTypes.MSC2716_INSERTION event.type == EventTypes.MSC2716_INSERTION
and state_event_ids and state_event_ids
@ -1445,7 +1449,7 @@ class EventCreationHandler:
""" """
extra_users = extra_users or [] extra_users = extra_users or []
assert self.storage.persistence is not None assert self._storage_controllers.persistence is not None
assert self._events_shard_config.should_handle( assert self._events_shard_config.should_handle(
self._instance_name, event.room_id self._instance_name, event.room_id
) )
@ -1679,7 +1683,7 @@ class EventCreationHandler:
event, event,
event_pos, event_pos,
max_stream_token, max_stream_token,
) = await self.storage.persistence.persist_event( ) = await self._storage_controllers.persistence.persist_event(
event, context=context, backfilled=backfilled event, context=context, backfilled=backfilled
) )

View file

@ -129,8 +129,8 @@ class PaginationHandler:
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
self.state_storage = self.storage.state self._state_storage_controller = self._storage_controllers.state
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._server_name = hs.hostname self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler() self._room_shutdown_handler = hs.get_room_shutdown_handler()
@ -352,7 +352,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id) self._purges_in_progress_by_room.add(room_id)
try: try:
async with self.pagination_lock.write(room_id): async with self.pagination_lock.write(room_id):
await self.storage.purge_events.purge_history( await self._storage_controllers.purge_events.purge_history(
room_id, token, delete_local_events room_id, token, delete_local_events
) )
logger.info("[purge] complete") logger.info("[purge] complete")
@ -414,7 +414,7 @@ class PaginationHandler:
if joined: if joined:
raise SynapseError(400, "Users are still joined to this room") raise SynapseError(400, "Users are still joined to this room")
await self.storage.purge_events.purge_room(room_id) await self._storage_controllers.purge_events.purge_room(room_id)
async def get_messages( async def get_messages(
self, self,
@ -529,7 +529,10 @@ class PaginationHandler:
events = await event_filter.filter(events) events = await event_filter.filter(events)
events = await filter_events_for_client( events = await filter_events_for_client(
self.storage, user_id, events, is_peeking=(member_event_id is None) self._storage_controllers,
user_id,
events,
is_peeking=(member_event_id is None),
) )
# if after the filter applied there are no more events # if after the filter applied there are no more events
@ -550,7 +553,7 @@ class PaginationHandler:
(EventTypes.Member, event.sender) for event in events (EventTypes.Member, event.sender) for event in events
) )
state_ids = await self.state_storage.get_state_ids_for_event( state_ids = await self._state_storage_controller.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter events[0].event_id, state_filter=state_filter
) )
@ -664,7 +667,7 @@ class PaginationHandler:
400, "Users are still joined to this room" 400, "Users are still joined to this room"
) )
await self.storage.purge_events.purge_room(room_id) await self._storage_controllers.purge_events.purge_room(room_id)
logger.info("complete") logger.info("complete")
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE

View file

@ -69,7 +69,7 @@ class BundledAggregations:
class RelationsHandler: class RelationsHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self._main_store = hs.get_datastores().main self._main_store = hs.get_datastores().main
self._storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
self._auth = hs.get_auth() self._auth = hs.get_auth()
self._clock = hs.get_clock() self._clock = hs.get_clock()
self._event_handler = hs.get_event_handler() self._event_handler = hs.get_event_handler()
@ -143,7 +143,10 @@ class RelationsHandler:
) )
events = await filter_events_for_client( events = await filter_events_for_client(
self._storage, user_id, events, is_peeking=(member_event_id is None) self._storage_controllers,
user_id,
events,
is_peeking=(member_event_id is None),
) )
now = self._clock.time_msec() now = self._clock.time_msec()

View file

@ -1192,8 +1192,8 @@ class RoomContextHandler:
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
self.state_storage = self.storage.state self._state_storage_controller = self._storage_controllers.state
self._relations_handler = hs.get_relations_handler() self._relations_handler = hs.get_relations_handler()
async def get_event_context( async def get_event_context(
@ -1236,7 +1236,10 @@ class RoomContextHandler:
if use_admin_priviledge: if use_admin_priviledge:
return events return events
return await filter_events_for_client( return await filter_events_for_client(
self.storage, user.to_string(), events, is_peeking=is_peeking self._storage_controllers,
user.to_string(),
events,
is_peeking=is_peeking,
) )
event = await self.store.get_event( event = await self.store.get_event(
@ -1293,7 +1296,7 @@ class RoomContextHandler:
# first? Shouldn't we be consistent with /sync? # first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687 # https://github.com/matrix-org/matrix-doc/issues/687
state = await self.state_storage.get_state_for_events( state = await self._state_storage_controller.get_state_for_events(
[last_event_id], state_filter=state_filter [last_event_id], state_filter=state_filter
) )

View file

@ -17,7 +17,7 @@ class RoomBatchHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.state_storage = hs.get_storage().state self._state_storage_controller = hs.get_storage_controllers().state
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -141,7 +141,7 @@ class RoomBatchHandler:
) = await self.store.get_max_depth_of(event_ids) ) = await self.store.get_max_depth_of(event_ids)
# mapping from (type, state_key) -> state_event_id # mapping from (type, state_key) -> state_event_id
assert most_recent_event_id is not None assert most_recent_event_id is not None
prev_state_map = await self.state_storage.get_state_ids_for_event( prev_state_map = await self._state_storage_controller.get_state_ids_for_event(
most_recent_event_id most_recent_event_id
) )
# List of state event ID's # List of state event ID's

View file

@ -55,8 +55,8 @@ class SearchHandler:
self.hs = hs self.hs = hs
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler() self._relations_handler = hs.get_relations_handler()
self.storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
self.state_storage = self.storage.state self._state_storage_controller = self._storage_controllers.state
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]: async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
@ -460,7 +460,7 @@ class SearchHandler:
filtered_events = await search_filter.filter([r["event"] for r in results]) filtered_events = await search_filter.filter([r["event"] for r in results])
events = await filter_events_for_client( events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events self._storage_controllers, user.to_string(), filtered_events
) )
events.sort(key=lambda e: -rank_map[e.event_id]) events.sort(key=lambda e: -rank_map[e.event_id])
@ -559,7 +559,7 @@ class SearchHandler:
filtered_events = await search_filter.filter([r["event"] for r in results]) filtered_events = await search_filter.filter([r["event"] for r in results])
events = await filter_events_for_client( events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events self._storage_controllers, user.to_string(), filtered_events
) )
room_events.extend(events) room_events.extend(events)
@ -644,11 +644,11 @@ class SearchHandler:
) )
events_before = await filter_events_for_client( events_before = await filter_events_for_client(
self.storage, user.to_string(), res.events_before self._storage_controllers, user.to_string(), res.events_before
) )
events_after = await filter_events_for_client( events_after = await filter_events_for_client(
self.storage, user.to_string(), res.events_after self._storage_controllers, user.to_string(), res.events_after
) )
context: JsonDict = { context: JsonDict = {
@ -677,7 +677,7 @@ class SearchHandler:
[(EventTypes.Member, sender) for sender in senders] [(EventTypes.Member, sender) for sender in senders]
) )
state = await self.state_storage.get_state_for_event( state = await self._state_storage_controller.get_state_for_event(
last_event_id, state_filter last_event_id, state_filter
) )

View file

@ -238,8 +238,8 @@ class SyncHandler:
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
self.state_storage = self.storage.state self._state_storage_controller = self._storage_controllers.state
# TODO: flush cache entries on subsequent sync request. # TODO: flush cache entries on subsequent sync request.
# Once we get the next /sync request (ie, one with the same access token # Once we get the next /sync request (ie, one with the same access token
@ -512,7 +512,7 @@ class SyncHandler:
current_state_ids = frozenset(current_state_ids_map.values()) current_state_ids = frozenset(current_state_ids_map.values())
recents = await filter_events_for_client( recents = await filter_events_for_client(
self.storage, self._storage_controllers,
sync_config.user.to_string(), sync_config.user.to_string(),
recents, recents,
always_include_ids=current_state_ids, always_include_ids=current_state_ids,
@ -580,7 +580,7 @@ class SyncHandler:
current_state_ids = frozenset(current_state_ids_map.values()) current_state_ids = frozenset(current_state_ids_map.values())
loaded_recents = await filter_events_for_client( loaded_recents = await filter_events_for_client(
self.storage, self._storage_controllers,
sync_config.user.to_string(), sync_config.user.to_string(),
loaded_recents, loaded_recents,
always_include_ids=current_state_ids, always_include_ids=current_state_ids,
@ -630,7 +630,7 @@ class SyncHandler:
event: event of interest event: event of interest
state_filter: The state filter used to fetch state from the database. state_filter: The state filter used to fetch state from the database.
""" """
state_ids = await self.state_storage.get_state_ids_for_event( state_ids = await self._state_storage_controller.get_state_ids_for_event(
event.event_id, state_filter=state_filter or StateFilter.all() event.event_id, state_filter=state_filter or StateFilter.all()
) )
if event.is_state(): if event.is_state():
@ -710,7 +710,7 @@ class SyncHandler:
return None return None
last_event = last_events[-1] last_event = last_events[-1]
state_ids = await self.state_storage.get_state_ids_for_event( state_ids = await self._state_storage_controller.get_state_ids_for_event(
last_event.event_id, last_event.event_id,
state_filter=StateFilter.from_types( state_filter=StateFilter.from_types(
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@ -889,13 +889,15 @@ class SyncHandler:
if full_state: if full_state:
if batch: if batch:
current_state_ids = ( current_state_ids = (
await self.state_storage.get_state_ids_for_event( await self._state_storage_controller.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter batch.events[-1].event_id, state_filter=state_filter
) )
) )
state_ids = await self.state_storage.get_state_ids_for_event( state_ids = (
batch.events[0].event_id, state_filter=state_filter await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
)
) )
else: else:
@ -915,7 +917,7 @@ class SyncHandler:
elif batch.limited: elif batch.limited:
if batch: if batch:
state_at_timeline_start = ( state_at_timeline_start = (
await self.state_storage.get_state_ids_for_event( await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter batch.events[0].event_id, state_filter=state_filter
) )
) )
@ -950,7 +952,7 @@ class SyncHandler:
if batch: if batch:
current_state_ids = ( current_state_ids = (
await self.state_storage.get_state_ids_for_event( await self._state_storage_controller.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter batch.events[-1].event_id, state_filter=state_filter
) )
) )
@ -982,7 +984,7 @@ class SyncHandler:
# So we fish out all the member events corresponding to the # So we fish out all the member events corresponding to the
# timeline here, and then dedupe any redundant ones below. # timeline here, and then dedupe any redundant ones below.
state_ids = await self.state_storage.get_state_ids_for_event( state_ids = await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id, batch.events[0].event_id,
# we only want members! # we only want members!
state_filter=StateFilter.from_types( state_filter=StateFilter.from_types(

View file

@ -221,7 +221,7 @@ class Notifier:
self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {} self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {}
self.hs = hs self.hs = hs
self.storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.pending_new_room_events: List[_PendingRoomEventEntry] = [] self.pending_new_room_events: List[_PendingRoomEventEntry] = []
@ -623,7 +623,7 @@ class Notifier:
if name == "room": if name == "room":
new_events = await filter_events_for_client( new_events = await filter_events_for_client(
self.storage, self._storage_controllers,
user.to_string(), user.to_string(),
new_events, new_events,
is_peeking=is_peeking, is_peeking=is_peeking,

View file

@ -65,7 +65,7 @@ class HttpPusher(Pusher):
def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
super().__init__(hs, pusher_config) super().__init__(hs, pusher_config)
self.storage = self.hs.get_storage() self._storage_controllers = self.hs.get_storage_controllers()
self.app_display_name = pusher_config.app_display_name self.app_display_name = pusher_config.app_display_name
self.device_display_name = pusher_config.device_display_name self.device_display_name = pusher_config.device_display_name
self.pushkey_ts = pusher_config.ts self.pushkey_ts = pusher_config.ts
@ -343,7 +343,9 @@ class HttpPusher(Pusher):
} }
return d return d
ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id) ctx = await push_tools.get_context_for_event(
self._storage_controllers, event, self.user_id
)
d = { d = {
"notification": { "notification": {

View file

@ -114,10 +114,10 @@ class Mailer:
self.send_email_handler = hs.get_send_email_handler() self.send_email_handler = hs.get_send_email_handler()
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
self.state_storage = self.hs.get_storage().state self._state_storage_controller = self.hs.get_storage_controllers().state
self.macaroon_gen = self.hs.get_macaroon_generator() self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler() self.state_handler = self.hs.get_state_handler()
self.storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
self.app_name = app_name self.app_name = app_name
self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects
@ -456,7 +456,7 @@ class Mailer:
} }
the_events = await filter_events_for_client( the_events = await filter_events_for_client(
self.storage, user_id, results.events_before self._storage_controllers, user_id, results.events_before
) )
the_events.append(notif_event) the_events.append(notif_event)
@ -494,7 +494,7 @@ class Mailer:
) )
else: else:
# Attempt to check the historical state for the room. # Attempt to check the historical state for the room.
historical_state = await self.state_storage.get_state_for_event( historical_state = await self._state_storage_controller.get_state_for_event(
event.event_id, StateFilter.from_types((type_state_key,)) event.event_id, StateFilter.from_types((type_state_key,))
) )
sender_state_event = historical_state.get(type_state_key) sender_state_event = historical_state.get(type_state_key)
@ -767,8 +767,10 @@ class Mailer:
member_event_ids.append(sender_state_event_id) member_event_ids.append(sender_state_event_id)
else: else:
# Attempt to check the historical state for the room. # Attempt to check the historical state for the room.
historical_state = await self.state_storage.get_state_for_event( historical_state = (
event_id, StateFilter.from_types((type_state_key,)) await self._state_storage_controller.get_state_for_event(
event_id, StateFilter.from_types((type_state_key,))
)
) )
sender_state_event = historical_state.get(type_state_key) sender_state_event = historical_state.get(type_state_key)
if sender_state_event: if sender_state_event:

View file

@ -16,7 +16,7 @@ from typing import Dict
from synapse.api.constants import ReceiptTypes from synapse.api.constants import ReceiptTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
@ -52,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
async def get_context_for_event( async def get_context_for_event(
storage: Storage, ev: EventBase, user_id: str storage: StorageControllers, ev: EventBase, user_id: str
) -> Dict[str, str]: ) -> Dict[str, str]:
ctx = {} ctx = {}

View file

@ -69,7 +69,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
super().__init__(hs) super().__init__(hs)
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.federation_event_handler = hs.get_federation_event_handler() self.federation_event_handler = hs.get_federation_event_handler()
@ -133,7 +133,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event.internal_metadata.outlier = event_payload["outlier"] event.internal_metadata.outlier = event_payload["outlier"]
context = EventContext.deserialize( context = EventContext.deserialize(
self.storage, event_payload["context"] self._storage_controllers, event_payload["context"]
) )
event_and_contexts.append((event, context)) event_and_contexts.append((event, context))

View file

@ -70,7 +70,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@staticmethod @staticmethod
@ -127,7 +127,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
event.internal_metadata.outlier = content["outlier"] event.internal_metadata.outlier = content["outlier"]
requester = Requester.deserialize(self.store, content["requester"]) requester = Requester.deserialize(self.store, content["requester"])
context = EventContext.deserialize(self.storage, content["context"]) context = EventContext.deserialize(
self._storage_controllers, content["context"]
)
ratelimit = content["ratelimit"] ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]] extra_users = [UserID.from_string(u) for u in content["extra_users"]]

View file

@ -123,7 +123,8 @@ from synapse.server_notices.worker_server_notices_sender import (
WorkerServerNoticesSender, WorkerServerNoticesSender,
) )
from synapse.state import StateHandler, StateResolutionHandler from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import Databases, Storage from synapse.storage import Databases
from synapse.storage.controllers import StorageControllers
from synapse.streams.events import EventSources from synapse.streams.events import EventSources
from synapse.types import DomainSpecificString, ISynapseReactor from synapse.types import DomainSpecificString, ISynapseReactor
from synapse.util import Clock from synapse.util import Clock
@ -729,8 +730,8 @@ class HomeServer(metaclass=abc.ABCMeta):
return PasswordPolicyHandler(self) return PasswordPolicyHandler(self)
@cache_in_self @cache_in_self
def get_storage(self) -> Storage: def get_storage_controllers(self) -> StorageControllers:
return Storage(self, self.get_datastores()) return StorageControllers(self, self.get_datastores())
@cache_in_self @cache_in_self
def get_replication_streamer(self) -> ReplicationStreamer: def get_replication_streamer(self) -> ReplicationStreamer:

View file

@ -127,10 +127,10 @@ class StateHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.state_storage = hs.get_storage().state self._state_storage_controller = hs.get_storage_controllers().state
self.hs = hs self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler() self._state_resolution_handler = hs.get_state_resolution_handler()
self._storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
@overload @overload
async def get_current_state( async def get_current_state(
@ -337,12 +337,14 @@ class StateHandler:
# #
if not state_group_before_event: if not state_group_before_event:
state_group_before_event = await self.state_storage.store_state_group( state_group_before_event = (
event.event_id, await self._state_storage_controller.store_state_group(
event.room_id, event.event_id,
prev_group=state_group_before_event_prev_group, event.room_id,
delta_ids=deltas_to_state_group_before_event, prev_group=state_group_before_event_prev_group,
current_state_ids=state_ids_before_event, delta_ids=deltas_to_state_group_before_event,
current_state_ids=state_ids_before_event,
)
) )
# Assign the new state group to the cached state entry. # Assign the new state group to the cached state entry.
@ -359,7 +361,7 @@ class StateHandler:
if not event.is_state(): if not event.is_state():
return EventContext.with_state( return EventContext.with_state(
storage=self._storage, storage=self._storage_controllers,
state_group_before_event=state_group_before_event, state_group_before_event=state_group_before_event,
state_group=state_group_before_event, state_group=state_group_before_event,
state_delta_due_to_event={}, state_delta_due_to_event={},
@ -382,16 +384,18 @@ class StateHandler:
state_ids_after_event[key] = event.event_id state_ids_after_event[key] = event.event_id
delta_ids = {key: event.event_id} delta_ids = {key: event.event_id}
state_group_after_event = await self.state_storage.store_state_group( state_group_after_event = (
event.event_id, await self._state_storage_controller.store_state_group(
event.room_id, event.event_id,
prev_group=state_group_before_event, event.room_id,
delta_ids=delta_ids, prev_group=state_group_before_event,
current_state_ids=state_ids_after_event, delta_ids=delta_ids,
current_state_ids=state_ids_after_event,
)
) )
return EventContext.with_state( return EventContext.with_state(
storage=self._storage, storage=self._storage_controllers,
state_group=state_group_after_event, state_group=state_group_after_event,
state_group_before_event=state_group_before_event, state_group_before_event=state_group_before_event,
state_delta_due_to_event=delta_ids, state_delta_due_to_event=delta_ids,
@ -416,7 +420,9 @@ class StateHandler:
""" """
logger.debug("resolve_state_groups event_ids %s", event_ids) logger.debug("resolve_state_groups event_ids %s", event_ids)
state_groups = await self.state_storage.get_state_group_for_events(event_ids) state_groups = await self._state_storage_controller.get_state_group_for_events(
event_ids
)
state_group_ids = state_groups.values() state_group_ids = state_groups.values()
@ -424,8 +430,13 @@ class StateHandler:
state_group_ids_set = set(state_group_ids) state_group_ids_set = set(state_group_ids)
if len(state_group_ids_set) == 1: if len(state_group_ids_set) == 1:
(state_group_id,) = state_group_ids_set (state_group_id,) = state_group_ids_set
state = await self.state_storage.get_state_for_groups(state_group_ids_set) state = await self._state_storage_controller.get_state_for_groups(
prev_group, delta_ids = await self.state_storage.get_state_group_delta( state_group_ids_set
)
(
prev_group,
delta_ids,
) = await self._state_storage_controller.get_state_group_delta(
state_group_id state_group_id
) )
return _StateCacheEntry( return _StateCacheEntry(
@ -439,7 +450,7 @@ class StateHandler:
room_version = await self.store.get_room_version_id(room_id) room_version = await self.store.get_room_version_id(room_id)
state_to_resolve = await self.state_storage.get_state_for_groups( state_to_resolve = await self._state_storage_controller.get_state_for_groups(
state_group_ids_set state_group_ids_set
) )

View file

@ -18,41 +18,20 @@ The storage layer is split up into multiple parts to allow Synapse to run
against different configurations of databases (e.g. single or multiple against different configurations of databases (e.g. single or multiple
databases). The `DatabasePool` class represents connections to a single physical databases). The `DatabasePool` class represents connections to a single physical
database. The `databases` are classes that talk directly to a `DatabasePool` database. The `databases` are classes that talk directly to a `DatabasePool`
instance and have associated schemas, background updates, etc. On top of those instance and have associated schemas, background updates, etc.
there are classes that provide high level interfaces that combine calls to
multiple `databases`. On top of the databases are the StorageControllers, located in the
`synapse.storage.controllers` module. These classes provide high level
interfaces that combine calls to multiple `databases`. They are bundled into the
`StorageControllers` singleton for ease of use, and exposed via
`HomeServer.get_storage_controllers()`.
There are also schemas that get applied to every database, regardless of the There are also schemas that get applied to every database, regardless of the
data stores associated with them (e.g. the schema version tables), which are data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`. stored in `synapse.storage.schema`.
""" """
from typing import TYPE_CHECKING
from synapse.storage.databases import Databases from synapse.storage.databases import Databases
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.storage.persist_events import EventsPersistenceStorage
from synapse.storage.purge_events import PurgeEventsStorage
from synapse.storage.state import StateGroupStorage
if TYPE_CHECKING:
from synapse.server import HomeServer
__all__ = ["Databases", "DataStore"] __all__ = ["Databases", "DataStore"]
class Storage:
"""The high level interfaces for talking to various storage layers."""
def __init__(self, hs: "HomeServer", stores: Databases):
# We include the main data store here mainly so that we don't have to
# rewrite all the existing code to split it into high vs low level
# interfaces.
self.main = stores.main
self.purge_events = PurgeEventsStorage(hs, stores)
self.state = StateGroupStorage(hs, stores)
self.persistence = None
if stores.persist_events:
self.persistence = EventsPersistenceStorage(hs, stores)

View file

@ -0,0 +1,46 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from synapse.storage.controllers.persist_events import (
EventsPersistenceStorageController,
)
from synapse.storage.controllers.purge_events import PurgeEventsStorageController
from synapse.storage.controllers.state import StateGroupStorageController
from synapse.storage.databases import Databases
from synapse.storage.databases.main import DataStore
if TYPE_CHECKING:
from synapse.server import HomeServer
__all__ = ["Databases", "DataStore"]
class StorageControllers:
"""The high level interfaces for talking to various storage controller layers."""
def __init__(self, hs: "HomeServer", stores: Databases):
# We include the main data store here mainly so that we don't have to
# rewrite all the existing code to split it into high vs low level
# interfaces.
self.main = stores.main
self.purge_events = PurgeEventsStorageController(hs, stores)
self.state = StateGroupStorageController(hs, stores)
self.persistence = None
if stores.persist_events:
self.persistence = EventsPersistenceStorageController(hs, stores)

View file

@ -272,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
pass pass
class EventsPersistenceStorage: class EventsPersistenceStorageController:
"""High level interface for handling persisting newly received events. """High level interface for handling persisting newly received events.
Takes care of batching up events by room, and calculating the necessary Takes care of batching up events by room, and calculating the necessary

View file

@ -24,7 +24,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PurgeEventsStorage: class PurgeEventsStorageController:
"""High level interface for purging rooms and event history.""" """High level interface for purging rooms and event history."""
def __init__(self, hs: "HomeServer", stores: Databases): def __init__(self, hs: "HomeServer", stores: Databases):

View file

@ -0,0 +1,351 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
Awaitable,
Collection,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
)
from synapse.events import EventBase
from synapse.storage.state import StateFilter
from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
from synapse.types import MutableStateMap, StateMap
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases import Databases
logger = logging.getLogger(__name__)
class StateGroupStorageController:
"""High level interface to fetching state for event."""
def __init__(self, hs: "HomeServer", stores: "Databases"):
self._is_mine_id = hs.is_mine_id
self.stores = stores
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
def notify_event_un_partial_stated(self, event_id: str) -> None:
self._partial_state_events_tracker.notify_un_partial_stated(event_id)
async def get_state_group_delta(
self, state_group: int
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
"""Given a state group try to return a previous group and a delta between
the old and the new.
Args:
state_group: The state group used to retrieve state deltas.
Returns:
A tuple of the previous group and a state map of the event IDs which
make up the delta between the old and new state groups.
"""
state_group_delta = await self.stores.state.get_state_group_delta(state_group)
return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids(
self, _room_id: str, event_ids: Collection[str]
) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
Args:
_room_id: id of the room for these events
event_ids: ids of the events
Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id)
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
if not event_ids:
return {}
event_to_groups = await self.get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups)
return group_to_state
async def get_state_ids_for_group(
self, state_group: int, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""Get the event IDs of all the state in the given state group
Args:
state_group: A state group for which we want to get the state IDs.
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
Returns:
Resolves to a map of (type, state_key) -> event_id
"""
group_to_state = await self.get_state_for_groups((state_group,), state_filter)
return group_to_state[state_group]
async def get_state_groups(
self, room_id: str, event_ids: Collection[str]
) -> Dict[int, List[EventBase]]:
"""Get the state groups for the given list of event_ids
Args:
room_id: ID of the room for these events.
event_ids: The event IDs to retrieve state for.
Returns:
dict of state_group_id -> list of state events.
"""
if not event_ids:
return {}
group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
state_event_map = await self.stores.main.get_events(
[
ev_id
for group_ids in group_to_ids.values()
for ev_id in group_ids.values()
],
get_prev_content=False,
)
return {
group: [
state_event_map[v]
for v in event_id_map.values()
if v in state_event_map
]
for group, event_id_map in group_to_ids.items()
}
def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
) -> Awaitable[Dict[int, StateMap[str]]]:
"""Returns the state groups for a given set of groups, filtering on
types of state events.
Args:
groups: list of state group IDs to query
state_filter: The state filter used to fetch state
from the database.
Returns:
Dict of state group to state map.
"""
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
async def get_state_for_events(
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
Args:
event_ids: The events to fetch the state of.
state_filter: The state filter used to fetch state.
Returns:
A dict of (event_id) -> (type, state_key) -> [state_events]
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all()
)
state_event_map = await self.stores.main.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False,
)
event_to_state = {
event_id: {
k: state_event_map[v]
for k, v in group_to_state[group].items()
if v in state_event_map
}
for event_id, group in event_to_groups.items()
}
return {event: event_to_state[event] for event in event_ids}
async def get_state_ids_for_events(
self,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
Args:
event_ids: events whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A dict from event_id -> (type, state_key) -> event_id
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all()
)
event_to_state = {
event_id: group_to_state[group]
for event_id, group in event_to_groups.items()
}
return {event: event_to_state[event] for event in event_ids}
async def get_state_for_event(
self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[EventBase]:
"""
Get the state dict corresponding to a particular event
Args:
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A dict from (type, state_key) -> state_event
Raises:
RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown)
"""
state_map = await self.get_state_for_events(
[event_id], state_filter or StateFilter.all()
)
return state_map[event_id]
async def get_state_ids_for_event(
self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event
Args:
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A dict from (type, state_key) -> state_event_id
Raises:
RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown)
"""
state_map = await self.get_state_ids_for_events(
[event_id], state_filter or StateFilter.all()
)
return state_map[event_id]
def get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
Args:
groups: list of state groups for which we want to get the state.
state_filter: The state filter used to fetch state.
from the database.
Returns:
Dict of state group to state map.
"""
return self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all()
)
async def get_state_group_for_events(
self,
event_ids: Collection[str],
await_full_state: bool = True,
) -> Mapping[str, int]:
"""Returns mapping event_id -> state_group
Args:
event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete
state at these events.
"""
if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids)
return await self.stores.main._get_state_group_for_events(event_ids)
async def store_state_group(
self,
event_id: str,
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[StateMap[str]],
current_state_ids: StateMap[str],
) -> int:
"""Store a new set of state, returning a newly assigned state group.
Args:
event_id: The event ID for which the state was calculated.
room_id: ID of the room for which the state was calculated.
prev_group: A previous state group for the room, optional.
delta_ids: The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
current_state_ids: The state to store. Map of (type, state_key)
to event_id.
Returns:
The state group ID
"""
return await self.stores.state.store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)

View file

@ -15,7 +15,6 @@
import logging import logging
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Awaitable,
Callable, Callable,
Collection, Collection,
Dict, Dict,
@ -32,15 +31,11 @@ import attr
from frozendict import frozendict from frozendict import frozendict
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
from synapse.types import MutableStateMap, StateKey, StateMap from synapse.types import MutableStateMap, StateKey, StateMap
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
from synapse.server import HomeServer
from synapse.storage.databases import Databases
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -578,318 +573,3 @@ _ALL_NON_MEMBER_STATE_FILTER = StateFilter(
types=frozendict({EventTypes.Member: frozenset()}), include_others=True types=frozendict({EventTypes.Member: frozenset()}), include_others=True
) )
_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) _NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)
class StateGroupStorage:
"""High level interface to fetching state for event."""
def __init__(self, hs: "HomeServer", stores: "Databases"):
self._is_mine_id = hs.is_mine_id
self.stores = stores
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
def notify_event_un_partial_stated(self, event_id: str) -> None:
self._partial_state_events_tracker.notify_un_partial_stated(event_id)
async def get_state_group_delta(
self, state_group: int
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
"""Given a state group try to return a previous group and a delta between
the old and the new.
Args:
state_group: The state group used to retrieve state deltas.
Returns:
A tuple of the previous group and a state map of the event IDs which
make up the delta between the old and new state groups.
"""
state_group_delta = await self.stores.state.get_state_group_delta(state_group)
return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids(
self, _room_id: str, event_ids: Collection[str]
) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
Args:
_room_id: id of the room for these events
event_ids: ids of the events
Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id)
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
if not event_ids:
return {}
event_to_groups = await self.get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups)
return group_to_state
async def get_state_ids_for_group(
self, state_group: int, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""Get the event IDs of all the state in the given state group
Args:
state_group: A state group for which we want to get the state IDs.
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
Returns:
Resolves to a map of (type, state_key) -> event_id
"""
group_to_state = await self.get_state_for_groups((state_group,), state_filter)
return group_to_state[state_group]
async def get_state_groups(
self, room_id: str, event_ids: Collection[str]
) -> Dict[int, List[EventBase]]:
"""Get the state groups for the given list of event_ids
Args:
room_id: ID of the room for these events.
event_ids: The event IDs to retrieve state for.
Returns:
dict of state_group_id -> list of state events.
"""
if not event_ids:
return {}
group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
state_event_map = await self.stores.main.get_events(
[
ev_id
for group_ids in group_to_ids.values()
for ev_id in group_ids.values()
],
get_prev_content=False,
)
return {
group: [
state_event_map[v]
for v in event_id_map.values()
if v in state_event_map
]
for group, event_id_map in group_to_ids.items()
}
def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
) -> Awaitable[Dict[int, StateMap[str]]]:
"""Returns the state groups for a given set of groups, filtering on
types of state events.
Args:
groups: list of state group IDs to query
state_filter: The state filter used to fetch state
from the database.
Returns:
Dict of state group to state map.
"""
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
async def get_state_for_events(
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
Args:
event_ids: The events to fetch the state of.
state_filter: The state filter used to fetch state.
Returns:
A dict of (event_id) -> (type, state_key) -> [state_events]
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all()
)
state_event_map = await self.stores.main.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False,
)
event_to_state = {
event_id: {
k: state_event_map[v]
for k, v in group_to_state[group].items()
if v in state_event_map
}
for event_id, group in event_to_groups.items()
}
return {event: event_to_state[event] for event in event_ids}
async def get_state_ids_for_events(
self,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
Args:
event_ids: events whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A dict from event_id -> (type, state_key) -> event_id
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all()
)
event_to_state = {
event_id: group_to_state[group]
for event_id, group in event_to_groups.items()
}
return {event: event_to_state[event] for event in event_ids}
async def get_state_for_event(
self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[EventBase]:
"""
Get the state dict corresponding to a particular event
Args:
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A dict from (type, state_key) -> state_event
Raises:
RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown)
"""
state_map = await self.get_state_for_events(
[event_id], state_filter or StateFilter.all()
)
return state_map[event_id]
async def get_state_ids_for_event(
self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event
Args:
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A dict from (type, state_key) -> state_event_id
Raises:
RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown)
"""
state_map = await self.get_state_ids_for_events(
[event_id], state_filter or StateFilter.all()
)
return state_map[event_id]
def get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
Args:
groups: list of state groups for which we want to get the state.
state_filter: The state filter used to fetch state.
from the database.
Returns:
Dict of state group to state map.
"""
return self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all()
)
async def get_state_group_for_events(
self,
event_ids: Collection[str],
await_full_state: bool = True,
) -> Mapping[str, int]:
"""Returns mapping event_id -> state_group
Args:
event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete
state at these events.
"""
if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids)
return await self.stores.main._get_state_group_for_events(event_ids)
async def store_state_group(
self,
event_id: str,
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[StateMap[str]],
current_state_ids: StateMap[str],
) -> int:
"""Store a new set of state, returning a newly assigned state group.
Args:
event_id: The event ID for which the state was calculated.
room_id: ID of the room for which the state was calculated.
prev_group: A previous state group for the room, optional.
delta_ids: The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
current_state_ids: The state to store. Map of (type, state_key)
to event_id.
Returns:
The state group ID
"""
return await self.stores.state.store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)

View file

@ -20,7 +20,7 @@ from typing_extensions import Final
from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.storage import Storage from synapse.storage.controllers import StorageControllers
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import RetentionPolicy, StateMap, get_domain_from_id from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
@ -47,7 +47,7 @@ _HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, ""
async def filter_events_for_client( async def filter_events_for_client(
storage: Storage, storage: StorageControllers,
user_id: str, user_id: str,
events: List[EventBase], events: List[EventBase],
is_peeking: bool = False, is_peeking: bool = False,
@ -268,7 +268,7 @@ async def filter_events_for_client(
async def filter_events_for_server( async def filter_events_for_server(
storage: Storage, storage: StorageControllers,
server_name: str, server_name: str,
events: List[EventBase], events: List[EventBase],
redact: bool = True, redact: bool = True,
@ -360,7 +360,7 @@ async def filter_events_for_server(
async def _event_to_history_vis( async def _event_to_history_vis(
storage: Storage, events: Collection[EventBase] storage: StorageControllers, events: Collection[EventBase]
) -> Dict[str, str]: ) -> Dict[str, str]:
"""Get the history visibility at each of the given events """Get the history visibility at each of the given events
@ -407,7 +407,7 @@ async def _event_to_history_vis(
async def _event_to_memberships( async def _event_to_memberships(
storage: Storage, events: Collection[EventBase], server_name: str storage: StorageControllers, events: Collection[EventBase], server_name: str
) -> Dict[str, StateMap[EventBase]]: ) -> Dict[str, StateMap[EventBase]]:
"""Get the remote membership list at each of the given events """Get the remote membership list at each of the given events

View file

@ -29,7 +29,7 @@ class TestEventContext(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
self.user_id = self.register_user("u1", "pass") self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass") self.user_tok = self.login("u1", "pass")
@ -87,7 +87,7 @@ class TestEventContext(unittest.HomeserverTestCase):
def _check_serialize_deserialize(self, event, context): def _check_serialize_deserialize(self, event, context):
serialized = self.get_success(context.serialize(event, self.store)) serialized = self.get_success(context.serialize(event, self.store))
d_context = EventContext.deserialize(self.storage, serialized) d_context = EventContext.deserialize(self._storage_controllers, serialized)
self.assertEqual(context.state_group, d_context.state_group) self.assertEqual(context.state_group, d_context.state_group)
self.assertEqual(context.rejected, d_context.rejected) self.assertEqual(context.rejected, d_context.rejected)

View file

@ -50,7 +50,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
hs = self.setup_test_homeserver(federation_http_client=None) hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler() self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.state_storage = hs.get_storage().state self.state_storage_controller = hs.get_storage_controllers().state
self._event_auth_handler = hs.get_event_auth_handler() self._event_auth_handler = hs.get_event_auth_handler()
return hs return hs
@ -338,7 +338,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# mapping from (type, state_key) -> state_event_id # mapping from (type, state_key) -> state_event_id
assert most_recent_prev_event_id is not None assert most_recent_prev_event_id is not None
prev_state_map = self.get_success( prev_state_map = self.get_success(
self.state_storage.get_state_ids_for_event(most_recent_prev_event_id) self.state_storage_controller.get_state_ids_for_event(
most_recent_prev_event_id
)
) )
# List of state event ID's # List of state event ID's
prev_state_ids = list(prev_state_map.values()) prev_state_ids = list(prev_state_map.values())

View file

@ -70,7 +70,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
) -> None: ) -> None:
OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
main_store = self.hs.get_datastores().main main_store = self.hs.get_datastores().main
state_storage = self.hs.get_storage().state state_storage_controller = self.hs.get_storage_controllers().state
# create the room # create the room
user_id = self.register_user("kermit", "test") user_id = self.register_user("kermit", "test")
@ -146,10 +146,11 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
) )
if prev_exists_as_outlier: if prev_exists_as_outlier:
prev_event.internal_metadata.outlier = True prev_event.internal_metadata.outlier = True
persistence = self.hs.get_storage().persistence persistence = self.hs.get_storage_controllers().persistence
self.get_success( self.get_success(
persistence.persist_event( persistence.persist_event(
prev_event, EventContext.for_outlier(self.hs.get_storage()) prev_event,
EventContext.for_outlier(self.hs.get_storage_controllers()),
) )
) )
else: else:
@ -216,7 +217,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
# check that the state at that event is as expected # check that the state at that event is as expected
state = self.get_success( state = self.get_success(
state_storage.get_state_ids_for_event(pulled_event.event_id) state_storage_controller.get_state_ids_for_event(pulled_event.event_id)
) )
expected_state = { expected_state = {
(e.type, e.state_key): e.event_id for e in state_at_prev_event (e.type, e.state_key): e.event_id for e in state_at_prev_event

View file

@ -37,7 +37,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.handler = self.hs.get_event_creation_handler() self.handler = self.hs.get_event_creation_handler()
self.persist_event_storage = self.hs.get_storage().persistence self._persist_event_storage_controller = (
self.hs.get_storage_controllers().persistence
)
self.user_id = self.register_user("tester", "foobar") self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar") self.access_token = self.login("tester", "foobar")
@ -65,7 +67,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
) )
) )
self.get_success( self.get_success(
self.persist_event_storage.persist_event(memberEvent, memberEventContext) self._persist_event_storage_controller.persist_event(
memberEvent, memberEventContext
)
) )
return memberEvent, memberEventContext return memberEvent, memberEventContext
@ -129,7 +133,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event3.event_id) self.assertNotEqual(event1.event_id, event3.event_id)
ret_event3, event_pos3, _ = self.get_success( ret_event3, event_pos3, _ = self.get_success(
self.persist_event_storage.persist_event(event3, context) self._persist_event_storage_controller.persist_event(event3, context)
) )
# Assert that the returned values match those from the initial event # Assert that the returned values match those from the initial event
@ -143,7 +147,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event3.event_id) self.assertNotEqual(event1.event_id, event3.event_id)
events, _ = self.get_success( events, _ = self.get_success(
self.persist_event_storage.persist_events([(event3, context)]) self._persist_event_storage_controller.persist_events([(event3, context)])
) )
ret_event4 = events[0] ret_event4 = events[0]
@ -166,7 +170,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event2.event_id) self.assertNotEqual(event1.event_id, event2.event_id)
events, _ = self.get_success( events, _ = self.get_success(
self.persist_event_storage.persist_events( self._persist_event_storage_controller.persist_events(
[(event1, context1), (event2, context2)] [(event1, context1), (event2, context2)]
) )
) )

View file

@ -954,7 +954,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
) )
self.get_success( self.get_success(
self.hs.get_storage().persistence.persist_event(event, context) self.hs.get_storage_controllers().persistence.persist_event(event, context)
) )
def test_local_user_leaving_room_remains_in_user_directory(self) -> None: def test_local_user_leaving_room_remains_in_user_directory(self) -> None:

View file

@ -32,7 +32,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
self.master_store = hs.get_datastores().main self.master_store = hs.get_datastores().main
self.slaved_store = self.worker_hs.get_datastores().main self.slaved_store = self.worker_hs.get_datastores().main
self.storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
def replicate(self): def replicate(self):
"""Tell the master side of replication that something has happened, and then """Tell the master side of replication that something has happened, and then

View file

@ -262,7 +262,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
) )
msg, msgctx = self.build_event() msg, msgctx = self.build_event()
self.get_success( self.get_success(
self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)]) self._storage_controllers.persistence.persist_events(
[(j2, j2ctx), (msg, msgctx)]
)
) )
self.replicate() self.replicate()
@ -323,12 +325,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if backfill: if backfill:
self.get_success( self.get_success(
self.storage.persistence.persist_events( self._storage_controllers.persistence.persist_events(
[(event, context)], backfilled=True [(event, context)], backfilled=True
) )
) )
else: else:
self.get_success(self.storage.persistence.persist_event(event, context)) self.get_success(
self._storage_controllers.persistence.persist_event(event, context)
)
return event return event

View file

@ -31,7 +31,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor, clock, homeserver):
super().prepare(reactor, clock, homeserver) super().prepare(reactor, clock, homeserver)
self.room_creator = homeserver.get_room_creation_handler() self.room_creator = homeserver.get_room_creation_handler()
self.persist_event_storage = self.hs.get_storage().persistence self.persist_event_storage_controller = (
self.hs.get_storage_controllers().persistence
)
# Create a test user # Create a test user
self.ourUser = UserID.from_string(OUR_USER_ID) self.ourUser = UserID.from_string(OUR_USER_ID)
@ -61,7 +63,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
) )
) )
self.get_success( self.get_success(
self.persist_event_storage.persist_event(memberEvent, memberEventContext) self.persist_event_storage_controller.persist_event(
memberEvent, memberEventContext
)
) )
# Join the second user to the second room # Join the second user to the second room
@ -76,7 +80,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
) )
) )
self.get_success( self.get_success(
self.persist_event_storage.persist_event(memberEvent, memberEventContext) self.persist_event_storage_controller.persist_event(
memberEvent, memberEventContext
)
) )
def test_return_empty_with_no_data(self): def test_return_empty_with_no_data(self):

View file

@ -2579,7 +2579,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass") other_user_tok = self.login("user", "pass")
event_builder_factory = self.hs.get_event_builder_factory() event_builder_factory = self.hs.get_event_builder_factory()
event_creation_handler = self.hs.get_event_creation_handler() event_creation_handler = self.hs.get_event_creation_handler()
storage = self.hs.get_storage() storage_controllers = self.hs.get_storage_controllers()
# Create two rooms, one with a local user only and one with both a local # Create two rooms, one with a local user only and one with both a local
# and remote user. # and remote user.
@ -2604,7 +2604,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
event_creation_handler.create_new_client_event(builder) event_creation_handler.create_new_client_event(builder)
) )
self.get_success(storage.persistence.persist_event(event, context)) self.get_success(storage_controllers.persistence.persist_event(event, context))
# Now get rooms # Now get rooms
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"

View file

@ -130,7 +130,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
We do this by setting a very long time between purge jobs. We do this by setting a very long time between purge jobs.
""" """
store = self.hs.get_datastores().main store = self.hs.get_datastores().main
storage = self.hs.get_storage() storage_controllers = self.hs.get_storage_controllers()
room_id = self.helper.create_room_as(self.user_id, tok=self.token) room_id = self.helper.create_room_as(self.user_id, tok=self.token)
# Send a first event, which should be filtered out at the end of the test. # Send a first event, which should be filtered out at the end of the test.
@ -155,7 +155,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(2, len(events), "events retrieved from database") self.assertEqual(2, len(events), "events retrieved from database")
filtered_events = self.get_success( filtered_events = self.get_success(
filter_events_for_client(storage, self.user_id, events) filter_events_for_client(storage_controllers, self.user_id, events)
) )
# We should only get one event back. # We should only get one event back.

View file

@ -88,7 +88,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.clock = clock self.clock = clock
self.storage = hs.get_storage() self._storage_controllers = hs.get_storage_controllers()
self.virtual_user_id, _ = self.register_appservice_user( self.virtual_user_id, _ = self.register_appservice_user(
"as_user_potato", self.appservice.token "as_user_potato", self.appservice.token
@ -168,7 +168,9 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
# Fetch the state_groups # Fetch the state_groups
state_group_map = self.get_success( state_group_map = self.get_success(
self.storage.state.get_state_groups_ids(room_id, historical_event_ids) self._storage_controllers.state.get_state_groups_ids(
room_id, historical_event_ids
)
) )
# We expect all of the historical events to be using the same state_group # We expect all of the historical events to be using the same state_group

View file

@ -393,7 +393,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
# We need to persist the events to the events and state_events # We need to persist the events to the events and state_events
# tables. # tables.
persist_events_store._store_event_txn( persist_events_store._store_event_txn(
txn, [(e, EventContext(self.hs.get_storage())) for e in events] txn,
[(e, EventContext(self.hs.get_storage_controllers())) for e in events],
) )
# Actually call the function that calculates the auth chain stuff. # Actually call the function that calculates the auth chain stuff.

View file

@ -31,7 +31,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor, clock, homeserver):
self.state = self.hs.get_state_handler() self.state = self.hs.get_state_handler()
self.persistence = self.hs.get_storage().persistence self._persistence = self.hs.get_storage_controllers().persistence
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
self.register_user("user", "pass") self.register_user("user", "pass")
@ -71,7 +71,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
context = self.get_success( context = self.get_success(
self.state.compute_event_context(event, state_ids_before_event=state) self.state.compute_event_context(event, state_ids_before_event=state)
) )
self.get_success(self.persistence.persist_event(event, context)) self.get_success(self._persistence.persist_event(event, context))
def assert_extremities(self, expected_extremities): def assert_extremities(self, expected_extremities):
"""Assert the current extremities for the room""" """Assert the current extremities for the room"""
@ -148,7 +148,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
) )
) )
self.get_success(self.persistence.persist_event(remote_event_2, context)) self.get_success(self._persistence.persist_event(remote_event_2, context))
# Check that we haven't dropped the old extremity. # Check that we haven't dropped the old extremity.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@ -353,7 +353,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor, clock, homeserver):
self.state = self.hs.get_state_handler() self.state = self.hs.get_state_handler()
self.persistence = self.hs.get_storage().persistence self._persistence = self.hs.get_storage_controllers().persistence
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main
def test_remote_user_rooms_cache_invalidated(self): def test_remote_user_rooms_cache_invalidated(self):
@ -390,7 +390,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
) )
context = self.get_success(self.state.compute_event_context(remote_event_1)) context = self.get_success(self.state.compute_event_context(remote_event_1))
self.get_success(self.persistence.persist_event(remote_event_1, context)) self.get_success(self._persistence.persist_event(remote_event_1, context))
# Call `get_rooms_for_user` to add the remote user to the cache # Call `get_rooms_for_user` to add the remote user to the cache
rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
@ -437,7 +437,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
) )
context = self.get_success(self.state.compute_event_context(remote_event_1)) context = self.get_success(self.state.compute_event_context(remote_event_1))
self.get_success(self.persistence.persist_event(remote_event_1, context)) self.get_success(self._persistence.persist_event(remote_event_1, context))
# Call `get_users_in_room` to add the remote user to the cache # Call `get_users_in_room` to add the remote user to the cache
users = self.get_success(self.store.get_users_in_room(room_id)) users = self.get_success(self.store.get_users_in_room(room_id))

View file

@ -31,7 +31,7 @@ class PurgeTests(HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id) self.room_id = self.helper.create_room_as(self.user_id)
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage = self.hs.get_storage() self._storage_controllers = self.hs.get_storage_controllers()
def test_purge_history(self): def test_purge_history(self):
""" """
@ -51,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
# Purge everything before this topological token # Purge everything before this topological token
self.get_success( self.get_success(
self.storage.purge_events.purge_history(self.room_id, token_str, True) self._storage_controllers.purge_events.purge_history(
self.room_id, token_str, True
)
) )
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
@ -79,7 +81,9 @@ class PurgeTests(HomeserverTestCase):
# Purge everything before this topological token # Purge everything before this topological token
f = self.get_failure( f = self.get_failure(
self.storage.purge_events.purge_history(self.room_id, event, True), self._storage_controllers.purge_events.purge_history(
self.room_id, event, True
),
SynapseError, SynapseError,
) )
self.assertIn("greater than forward", f.value.args[0]) self.assertIn("greater than forward", f.value.args[0])
@ -105,7 +109,9 @@ class PurgeTests(HomeserverTestCase):
self.assertIsNotNone(create_event) self.assertIsNotNone(create_event)
# Purge everything before this topological token # Purge everything before this topological token
self.get_success(self.storage.purge_events.purge_room(self.room_id)) self.get_success(
self._storage_controllers.purge_events.purge_room(self.room_id)
)
# The events aren't found. # The events aren't found.
self.store._invalidate_get_event_cache(create_event.event_id) self.store._invalidate_get_event_cache(create_event.event_id)

View file

@ -31,7 +31,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage = hs.get_storage() self._storage = hs.get_storage_controllers()
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
@ -71,7 +71,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
self.get_success(self.storage.persistence.persist_event(event, context)) self.get_success(self._storage.persistence.persist_event(event, context))
return event return event
@ -93,7 +93,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
self.get_success(self.storage.persistence.persist_event(event, context)) self.get_success(self._storage.persistence.persist_event(event, context))
return event return event
@ -114,7 +114,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
self.get_success(self.storage.persistence.persist_event(event, context)) self.get_success(self._storage.persistence.persist_event(event, context))
return event return event
@ -268,7 +268,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
) )
) )
self.get_success(self.storage.persistence.persist_event(event_1, context_1)) self.get_success(self._storage.persistence.persist_event(event_1, context_1))
event_2, context_2 = self.get_success( event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event( self.event_creation_handler.create_new_client_event(
@ -287,7 +287,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
) )
) )
) )
self.get_success(self.storage.persistence.persist_event(event_2, context_2)) self.get_success(self._storage.persistence.persist_event(event_2, context_2))
# fetch one of the redactions # fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1)) fetched = self.get_success(self.store.get_event(redaction_event_id1))
@ -411,7 +411,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
) )
self.get_success( self.get_success(
self.storage.persistence.persist_event(redaction_event, context) self._storage.persistence.persist_event(redaction_event, context)
) )
# Now lets jump to the future where we have censored the redaction event # Now lets jump to the future where we have censored the redaction event

View file

@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
# Room events need the full datastore, for persist_event() and # Room events need the full datastore, for persist_event() and
# get_room_state() # get_room_state()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage = hs.get_storage() self._storage = hs.get_storage_controllers()
self.event_factory = hs.get_event_factory() self.event_factory = hs.get_event_factory()
self.room = RoomID.from_string("!abcde:test") self.room = RoomID.from_string("!abcde:test")
@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
def inject_room_event(self, **kwargs): def inject_room_event(self, **kwargs):
self.get_success( self.get_success(
self.storage.persistence.persist_event( self._storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
) )
) )

View file

@ -99,7 +99,9 @@ class EventSearchInsertionTest(HomeserverTestCase):
prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
prev_event = self.get_success(store.get_event(prev_event_ids[0])) prev_event = self.get_success(store.get_event(prev_event_ids[0]))
prev_state_map = self.get_success( prev_state_map = self.get_success(
self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0]) self.hs.get_storage_controllers().state.get_state_ids_for_event(
prev_event_ids[0]
)
) )
event_dict = { event_dict = {

View file

@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class StateStoreTestCase(HomeserverTestCase): class StateStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage = hs.get_storage() self.storage = hs.get_storage_controllers()
self.state_datastore = self.storage.state.stores.state self.state_datastore = self.storage.state.stores.state
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()

View file

@ -179,12 +179,12 @@ class Graph:
class StateTestCase(unittest.TestCase): class StateTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.dummy_store = _DummyStore() self.dummy_store = _DummyStore()
storage = Mock(main=self.dummy_store, state=self.dummy_store) storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
hs = Mock( hs = Mock(
spec_set=[ spec_set=[
"config", "config",
"get_datastores", "get_datastores",
"get_storage", "get_storage_controllers",
"get_auth", "get_auth",
"get_state_handler", "get_state_handler",
"get_clock", "get_clock",
@ -199,7 +199,7 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock() hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs) hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
hs.get_storage.return_value = storage hs.get_storage_controllers.return_value = storage_controllers
self.state = StateHandler(hs) self.state = StateHandler(hs)
self.event_id = 0 self.event_id = 0

View file

@ -70,7 +70,7 @@ async def inject_event(
""" """
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs) event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
persistence = hs.get_storage().persistence persistence = hs.get_storage_controllers().persistence
assert persistence is not None assert persistence is not None
await persistence.persist_event(event, context) await persistence.persist_event(event, context)

View file

@ -34,7 +34,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
super(FilterEventsForServerTestCase, self).setUp() super(FilterEventsForServerTestCase, self).setUp()
self.event_creation_handler = self.hs.get_event_creation_handler() self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory() self.event_builder_factory = self.hs.get_event_builder_factory()
self.storage = self.hs.get_storage() self._storage_controllers = self.hs.get_storage_controllers()
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@ -60,7 +60,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
events_to_filter.append(evt) events_to_filter.append(evt)
filtered = self.get_success( filtered = self.get_success(
filter_events_for_server(self.storage, "test_server", events_to_filter) filter_events_for_server(
self._storage_controllers, "test_server", events_to_filter
)
) )
# the result should be 5 redacted events, and 5 unredacted events. # the result should be 5 redacted events, and 5 unredacted events.
@ -80,7 +82,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
outlier = self._inject_outlier() outlier = self._inject_outlier()
self.assertEqual( self.assertEqual(
self.get_success( self.get_success(
filter_events_for_server(self.storage, "remote_hs", [outlier]) filter_events_for_server(
self._storage_controllers, "remote_hs", [outlier]
)
), ),
[outlier], [outlier],
) )
@ -89,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
evt = self._inject_message("@unerased:local_hs") evt = self._inject_message("@unerased:local_hs")
filtered = self.get_success( filtered = self.get_success(
filter_events_for_server(self.storage, "remote_hs", [outlier, evt]) filter_events_for_server(
self._storage_controllers, "remote_hs", [outlier, evt]
)
) )
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}") self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
self.assertEqual(filtered[0], outlier) self.assertEqual(filtered[0], outlier)
@ -99,7 +105,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... but other servers should only be able to see the outlier (the other should # ... but other servers should only be able to see the outlier (the other should
# be redacted) # be redacted)
filtered = self.get_success( filtered = self.get_success(
filter_events_for_server(self.storage, "other_server", [outlier, evt]) filter_events_for_server(
self._storage_controllers, "other_server", [outlier, evt]
)
) )
self.assertEqual(filtered[0], outlier) self.assertEqual(filtered[0], outlier)
self.assertEqual(filtered[1].event_id, evt.event_id) self.assertEqual(filtered[1].event_id, evt.event_id)
@ -132,7 +140,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... and the filtering happens. # ... and the filtering happens.
filtered = self.get_success( filtered = self.get_success(
filter_events_for_server(self.storage, "test_server", events_to_filter) filter_events_for_server(
self._storage_controllers, "test_server", events_to_filter
)
) )
for i in range(0, len(events_to_filter)): for i in range(0, len(events_to_filter)):
@ -168,7 +178,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event, context = self.get_success( event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
self.get_success(self.storage.persistence.persist_event(event, context)) self.get_success(
self._storage_controllers.persistence.persist_event(event, context)
)
return event return event
def _inject_room_member( def _inject_room_member(
@ -194,7 +206,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
self.get_success(self.storage.persistence.persist_event(event, context)) self.get_success(
self._storage_controllers.persistence.persist_event(event, context)
)
return event return event
def _inject_message( def _inject_message(
@ -216,7 +230,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
self.get_success(self.storage.persistence.persist_event(event, context)) self.get_success(
self._storage_controllers.persistence.persist_event(event, context)
)
return event return event
def _inject_outlier(self) -> EventBase: def _inject_outlier(self) -> EventBase:
@ -234,8 +250,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
event.internal_metadata.outlier = True event.internal_metadata.outlier = True
self.get_success( self.get_success(
self.storage.persistence.persist_event( self._storage_controllers.persistence.persist_event(
event, EventContext.for_outlier(self.storage) event, EventContext.for_outlier(self._storage_controllers)
) )
) )
return event return event
@ -293,7 +309,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual( self.assertEqual(
self.get_success( self.get_success(
filter_events_for_client( filter_events_for_client(
self.hs.get_storage(), "@user:test", [invite_event, reject_event] self.hs.get_storage_controllers(),
"@user:test",
[invite_event, reject_event],
) )
), ),
[invite_event, reject_event], [invite_event, reject_event],
@ -303,7 +321,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual( self.assertEqual(
self.get_success( self.get_success(
filter_events_for_client( filter_events_for_client(
self.hs.get_storage(), "@other:test", [invite_event, reject_event] self.hs.get_storage_controllers(),
"@other:test",
[invite_event, reject_event],
) )
), ),
[], [],

View file

@ -264,7 +264,7 @@ class MockClock:
async def create_room(hs, room_id: str, creator_id: str): async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room""" """Creates and persist a creation event for the given room"""
persistence_store = hs.get_storage().persistence persistence_store = hs.get_storage_controllers().persistence
store = hs.get_datastores().main store = hs.get_datastores().main
event_builder_factory = hs.get_event_builder_factory() event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler() event_creation_handler = hs.get_event_creation_handler()