mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-29 07:28:55 +03:00
Wait for lazy join to complete when getting current state (#12872)
This commit is contained in:
parent
782cb7420a
commit
888a29f412
33 changed files with 361 additions and 82 deletions
1
changelog.d/12872.misc
Normal file
1
changelog.d/12872.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Faster room joins: when querying the current state of the room, wait for state to be populated.
|
|
@ -152,6 +152,7 @@ class ThirdPartyEventRules:
|
||||||
self.third_party_rules = None
|
self.third_party_rules = None
|
||||||
|
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
|
|
||||||
self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
|
self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
|
||||||
self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
|
self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
|
||||||
|
@ -463,7 +464,7 @@ class ThirdPartyEventRules:
|
||||||
Returns:
|
Returns:
|
||||||
A dict mapping (event type, state key) to state event.
|
A dict mapping (event type, state key) to state event.
|
||||||
"""
|
"""
|
||||||
state_ids = await self.store.get_filtered_current_state_ids(room_id)
|
state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
|
||||||
room_state_events = await self.store.get_events(state_ids.values())
|
room_state_events = await self.store.get_events(state_ids.values())
|
||||||
|
|
||||||
state_events = {}
|
state_events = {}
|
||||||
|
|
|
@ -118,6 +118,8 @@ class FederationServer(FederationBase):
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self._event_auth_handler = hs.get_event_auth_handler()
|
self._event_auth_handler = hs.get_event_auth_handler()
|
||||||
|
|
||||||
|
self._state_storage_controller = hs.get_storage_controllers().state
|
||||||
|
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
|
|
||||||
# Ensure the following handlers are loaded since they register callbacks
|
# Ensure the following handlers are loaded since they register callbacks
|
||||||
|
@ -1221,7 +1223,7 @@ class FederationServer(FederationBase):
|
||||||
Raises:
|
Raises:
|
||||||
AuthError if the server does not match the ACL
|
AuthError if the server does not match the ACL
|
||||||
"""
|
"""
|
||||||
state_ids = await self.store.get_current_state_ids(room_id)
|
state_ids = await self._state_storage_controller.get_current_state_ids(room_id)
|
||||||
acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
|
acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
|
||||||
|
|
||||||
if not acl_event_id:
|
if not acl_event_id:
|
||||||
|
|
|
@ -166,7 +166,7 @@ class DeviceWorkerHandler:
|
||||||
possibly_changed = set(changed)
|
possibly_changed = set(changed)
|
||||||
possibly_left = set()
|
possibly_left = set()
|
||||||
for room_id in rooms_changed:
|
for room_id in rooms_changed:
|
||||||
current_state_ids = await self.store.get_current_state_ids(room_id)
|
current_state_ids = await self._state_storage.get_current_state_ids(room_id)
|
||||||
|
|
||||||
# The user may have left the room
|
# The user may have left the room
|
||||||
# TODO: Check if they actually did or if we were just invited.
|
# TODO: Check if they actually did or if we were just invited.
|
||||||
|
|
|
@ -45,6 +45,7 @@ class DirectoryHandler:
|
||||||
self.appservice_handler = hs.get_application_service_handler()
|
self.appservice_handler = hs.get_application_service_handler()
|
||||||
self.event_creation_handler = hs.get_event_creation_handler()
|
self.event_creation_handler = hs.get_event_creation_handler()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
|
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
|
||||||
self.require_membership = hs.config.server.require_membership_for_aliases
|
self.require_membership = hs.config.server.require_membership_for_aliases
|
||||||
|
@ -463,7 +464,11 @@ class DirectoryHandler:
|
||||||
making_public = visibility == "public"
|
making_public = visibility == "public"
|
||||||
if making_public:
|
if making_public:
|
||||||
room_aliases = await self.store.get_aliases_for_room(room_id)
|
room_aliases = await self.store.get_aliases_for_room(room_id)
|
||||||
canonical_alias = await self.store.get_canonical_alias_for_room(room_id)
|
canonical_alias = (
|
||||||
|
await self._storage_controllers.state.get_canonical_alias_for_room(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
|
)
|
||||||
if canonical_alias:
|
if canonical_alias:
|
||||||
room_aliases.append(canonical_alias)
|
room_aliases.append(canonical_alias)
|
||||||
|
|
||||||
|
|
|
@ -750,7 +750,9 @@ class FederationHandler:
|
||||||
# Note that this requires the /send_join request to come back to the
|
# Note that this requires the /send_join request to come back to the
|
||||||
# same server.
|
# same server.
|
||||||
if room_version.msc3083_join_rules:
|
if room_version.msc3083_join_rules:
|
||||||
state_ids = await self.store.get_current_state_ids(room_id)
|
state_ids = await self._state_storage_controller.get_current_state_ids(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
if await self._event_auth_handler.has_restricted_join_rules(
|
if await self._event_auth_handler.has_restricted_join_rules(
|
||||||
state_ids, room_version
|
state_ids, room_version
|
||||||
):
|
):
|
||||||
|
@ -1552,6 +1554,9 @@ class FederationHandler:
|
||||||
success = await self.store.clear_partial_state_room(room_id)
|
success = await self.store.clear_partial_state_room(room_id)
|
||||||
if success:
|
if success:
|
||||||
logger.info("State resync complete for %s", room_id)
|
logger.info("State resync complete for %s", room_id)
|
||||||
|
self._storage_controllers.state.notify_room_un_partial_stated(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(faster_joins) update room stats and user directory?
|
# TODO(faster_joins) update room stats and user directory?
|
||||||
return
|
return
|
||||||
|
|
|
@ -217,7 +217,7 @@ class MessageHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
if membership == Membership.JOIN:
|
if membership == Membership.JOIN:
|
||||||
state_ids = await self.store.get_filtered_current_state_ids(
|
state_ids = await self._state_storage_controller.get_current_state_ids(
|
||||||
room_id, state_filter=state_filter
|
room_id, state_filter=state_filter
|
||||||
)
|
)
|
||||||
room_state = await self.store.get_events(state_ids.values())
|
room_state = await self.store.get_events(state_ids.values())
|
||||||
|
|
|
@ -134,6 +134,7 @@ class BasePresenceHandler(abc.ABC):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.presence_router = hs.get_presence_router()
|
self.presence_router = hs.get_presence_router()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
@ -1348,7 +1349,10 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
self._event_pos,
|
self._event_pos,
|
||||||
room_max_stream_ordering,
|
room_max_stream_ordering,
|
||||||
)
|
)
|
||||||
max_pos, deltas = await self.store.get_current_state_deltas(
|
(
|
||||||
|
max_pos,
|
||||||
|
deltas,
|
||||||
|
) = await self._storage_controllers.state.get_current_state_deltas(
|
||||||
self._event_pos, room_max_stream_ordering
|
self._event_pos, room_max_stream_ordering
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -87,6 +87,7 @@ class LoginDict(TypedDict):
|
||||||
class RegistrationHandler:
|
class RegistrationHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
@ -528,7 +529,7 @@ class RegistrationHandler:
|
||||||
|
|
||||||
if requires_invite:
|
if requires_invite:
|
||||||
# If the server is in the room, check if the room is public.
|
# If the server is in the room, check if the room is public.
|
||||||
state = await self.store.get_filtered_current_state_ids(
|
state = await self._storage_controllers.state.get_current_state_ids(
|
||||||
room_id, StateFilter.from_types([(EventTypes.JoinRules, "")])
|
room_id, StateFilter.from_types([(EventTypes.JoinRules, "")])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -107,6 +107,7 @@ class EventContext:
|
||||||
class RoomCreationHandler:
|
class RoomCreationHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
@ -480,9 +481,11 @@ class RoomCreationHandler:
|
||||||
if room_type == RoomTypes.SPACE:
|
if room_type == RoomTypes.SPACE:
|
||||||
types_to_copy.append((EventTypes.SpaceChild, None))
|
types_to_copy.append((EventTypes.SpaceChild, None))
|
||||||
|
|
||||||
old_room_state_ids = await self.store.get_filtered_current_state_ids(
|
old_room_state_ids = (
|
||||||
|
await self._storage_controllers.state.get_current_state_ids(
|
||||||
old_room_id, StateFilter.from_types(types_to_copy)
|
old_room_id, StateFilter.from_types(types_to_copy)
|
||||||
)
|
)
|
||||||
|
)
|
||||||
# map from event_id to BaseEvent
|
# map from event_id to BaseEvent
|
||||||
old_room_state_events = await self.store.get_events(old_room_state_ids.values())
|
old_room_state_events = await self.store.get_events(old_room_state_ids.values())
|
||||||
|
|
||||||
|
@ -558,9 +561,11 @@ class RoomCreationHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Transfer membership events
|
# Transfer membership events
|
||||||
old_room_member_state_ids = await self.store.get_filtered_current_state_ids(
|
old_room_member_state_ids = (
|
||||||
|
await self._storage_controllers.state.get_current_state_ids(
|
||||||
old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
|
old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# map from event_id to BaseEvent
|
# map from event_id to BaseEvent
|
||||||
old_room_member_state_events = await self.store.get_events(
|
old_room_member_state_events = await self.store.get_events(
|
||||||
|
|
|
@ -50,6 +50,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
|
||||||
class RoomListHandler:
|
class RoomListHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
|
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
|
||||||
self.response_cache: ResponseCache[
|
self.response_cache: ResponseCache[
|
||||||
|
@ -274,7 +275,7 @@ class RoomListHandler:
|
||||||
if aliases:
|
if aliases:
|
||||||
result["aliases"] = aliases
|
result["aliases"] = aliases
|
||||||
|
|
||||||
current_state_ids = await self.store.get_current_state_ids(
|
current_state_ids = await self._storage_controllers.state.get_current_state_ids(
|
||||||
room_id, on_invalidate=cache_context.invalidate
|
room_id, on_invalidate=cache_context.invalidate
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -68,6 +68,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.state_handler = hs.get_state_handler()
|
self.state_handler = hs.get_state_handler()
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
|
@ -994,7 +995,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
# If the host is in the room, but not one of the authorised hosts
|
# If the host is in the room, but not one of the authorised hosts
|
||||||
# for restricted join rules, a remote join must be used.
|
# for restricted join rules, a remote join must be used.
|
||||||
room_version = await self.store.get_room_version(room_id)
|
room_version = await self.store.get_room_version(room_id)
|
||||||
current_state_ids = await self.store.get_current_state_ids(room_id)
|
current_state_ids = await self._storage_controllers.state.get_current_state_ids(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
|
|
||||||
# If restricted join rules are not being used, a local join can always
|
# If restricted join rules are not being used, a local join can always
|
||||||
# be used.
|
# be used.
|
||||||
|
|
|
@ -90,6 +90,7 @@ class RoomSummaryHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self._event_auth_handler = hs.get_event_auth_handler()
|
self._event_auth_handler = hs.get_event_auth_handler()
|
||||||
self._store = hs.get_datastores().main
|
self._store = hs.get_datastores().main
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
self._server_name = hs.hostname
|
self._server_name = hs.hostname
|
||||||
self._federation_client = hs.get_federation_client()
|
self._federation_client = hs.get_federation_client()
|
||||||
|
@ -537,7 +538,7 @@ class RoomSummaryHandler:
|
||||||
Returns:
|
Returns:
|
||||||
True if the room is accessible to the requesting user or server.
|
True if the room is accessible to the requesting user or server.
|
||||||
"""
|
"""
|
||||||
state_ids = await self._store.get_current_state_ids(room_id)
|
state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
|
||||||
|
|
||||||
# If there's no state for the room, it isn't known.
|
# If there's no state for the room, it isn't known.
|
||||||
if not state_ids:
|
if not state_ids:
|
||||||
|
@ -702,7 +703,9 @@ class RoomSummaryHandler:
|
||||||
# there should always be an entry
|
# there should always be an entry
|
||||||
assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
|
assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
|
||||||
|
|
||||||
current_state_ids = await self._store.get_current_state_ids(room_id)
|
current_state_ids = await self._storage_controllers.state.get_current_state_ids(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
create_event = await self._store.get_event(
|
create_event = await self._store.get_event(
|
||||||
current_state_ids[(EventTypes.Create, "")]
|
current_state_ids[(EventTypes.Create, "")]
|
||||||
)
|
)
|
||||||
|
@ -760,7 +763,9 @@ class RoomSummaryHandler:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# look for child rooms/spaces.
|
# look for child rooms/spaces.
|
||||||
current_state_ids = await self._store.get_current_state_ids(room_id)
|
current_state_ids = await self._storage_controllers.state.get_current_state_ids(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
|
|
||||||
events = await self._store.get_events_as_list(
|
events = await self._store.get_events_as_list(
|
||||||
[
|
[
|
||||||
|
|
|
@ -40,6 +40,7 @@ class StatsHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
@ -105,7 +106,10 @@ class StatsHandler:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Processing room stats %s->%s", self.pos, room_max_stream_ordering
|
"Processing room stats %s->%s", self.pos, room_max_stream_ordering
|
||||||
)
|
)
|
||||||
max_pos, deltas = await self.store.get_current_state_deltas(
|
(
|
||||||
|
max_pos,
|
||||||
|
deltas,
|
||||||
|
) = await self._storage_controllers.state.get_current_state_deltas(
|
||||||
self.pos, room_max_stream_ordering
|
self.pos, room_max_stream_ordering
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -506,9 +506,11 @@ class SyncHandler:
|
||||||
# ensure that we always include current state in the timeline
|
# ensure that we always include current state in the timeline
|
||||||
current_state_ids: FrozenSet[str] = frozenset()
|
current_state_ids: FrozenSet[str] = frozenset()
|
||||||
if any(e.is_state() for e in recents):
|
if any(e.is_state() for e in recents):
|
||||||
current_state_ids_map = await self.store.get_current_state_ids(
|
current_state_ids_map = (
|
||||||
|
await self._state_storage_controller.get_current_state_ids(
|
||||||
room_id
|
room_id
|
||||||
)
|
)
|
||||||
|
)
|
||||||
current_state_ids = frozenset(current_state_ids_map.values())
|
current_state_ids = frozenset(current_state_ids_map.values())
|
||||||
|
|
||||||
recents = await filter_events_for_client(
|
recents = await filter_events_for_client(
|
||||||
|
@ -574,8 +576,11 @@ class SyncHandler:
|
||||||
# ensure that we always include current state in the timeline
|
# ensure that we always include current state in the timeline
|
||||||
current_state_ids = frozenset()
|
current_state_ids = frozenset()
|
||||||
if any(e.is_state() for e in loaded_recents):
|
if any(e.is_state() for e in loaded_recents):
|
||||||
current_state_ids_map = await self.store.get_current_state_ids(
|
# FIXME(faster_joins): We use the partial state here as
|
||||||
room_id
|
# we don't want to block `/sync` on finishing a lazy join.
|
||||||
|
# Is this the correct way of doing it?
|
||||||
|
current_state_ids_map = (
|
||||||
|
await self.store.get_partial_current_state_ids(room_id)
|
||||||
)
|
)
|
||||||
current_state_ids = frozenset(current_state_ids_map.values())
|
current_state_ids = frozenset(current_state_ids_map.values())
|
||||||
|
|
||||||
|
|
|
@ -56,6 +56,7 @@ class UserDirectoryHandler(StateDeltasHandler):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
|
@ -174,7 +175,10 @@ class UserDirectoryHandler(StateDeltasHandler):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Processing user stats %s->%s", self.pos, room_max_stream_ordering
|
"Processing user stats %s->%s", self.pos, room_max_stream_ordering
|
||||||
)
|
)
|
||||||
max_pos, deltas = await self.store.get_current_state_deltas(
|
(
|
||||||
|
max_pos,
|
||||||
|
deltas,
|
||||||
|
) = await self._storage_controllers.state.get_current_state_deltas(
|
||||||
self.pos, room_max_stream_ordering
|
self.pos, room_max_stream_ordering
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -194,6 +194,7 @@ class ModuleApi:
|
||||||
self._store: Union[
|
self._store: Union[
|
||||||
DataStore, "GenericWorkerSlavedStore"
|
DataStore, "GenericWorkerSlavedStore"
|
||||||
] = hs.get_datastores().main
|
] = hs.get_datastores().main
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self._auth = hs.get_auth()
|
self._auth = hs.get_auth()
|
||||||
self._auth_handler = auth_handler
|
self._auth_handler = auth_handler
|
||||||
self._server_name = hs.hostname
|
self._server_name = hs.hostname
|
||||||
|
@ -911,7 +912,7 @@ class ModuleApi:
|
||||||
The filtered state events in the room.
|
The filtered state events in the room.
|
||||||
"""
|
"""
|
||||||
state_ids = yield defer.ensureDeferred(
|
state_ids = yield defer.ensureDeferred(
|
||||||
self._store.get_filtered_current_state_ids(
|
self._storage_controllers.state.get_current_state_ids(
|
||||||
room_id=room_id, state_filter=StateFilter.from_types(types)
|
room_id=room_id, state_filter=StateFilter.from_types(types)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -1289,20 +1290,16 @@ class ModuleApi:
|
||||||
# regardless of their state key
|
# regardless of their state key
|
||||||
]
|
]
|
||||||
"""
|
"""
|
||||||
|
state_filter = None
|
||||||
if event_filter:
|
if event_filter:
|
||||||
# If a filter was provided, turn it into a StateFilter and retrieve a filtered
|
# If a filter was provided, turn it into a StateFilter and retrieve a filtered
|
||||||
# view of the state.
|
# view of the state.
|
||||||
state_filter = StateFilter.from_types(event_filter)
|
state_filter = StateFilter.from_types(event_filter)
|
||||||
state_ids = await self._store.get_filtered_current_state_ids(
|
|
||||||
|
state_ids = await self._storage_controllers.state.get_current_state_ids(
|
||||||
room_id,
|
room_id,
|
||||||
state_filter,
|
state_filter,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# If no filter was provided, get the whole state. We could also reuse the call
|
|
||||||
# to get_filtered_current_state_ids above, with `state_filter = StateFilter.all()`,
|
|
||||||
# but get_filtered_current_state_ids isn't cached and `get_current_state_ids`
|
|
||||||
# is, so using the latter when we can is better for perf.
|
|
||||||
state_ids = await self._store.get_current_state_ids(room_id)
|
|
||||||
|
|
||||||
state_events = await self._store.get_events(state_ids.values())
|
state_events = await self._store.get_events(state_ids.values())
|
||||||
|
|
||||||
|
|
|
@ -255,7 +255,9 @@ class Mailer:
|
||||||
user_display_name = user_id
|
user_display_name = user_id
|
||||||
|
|
||||||
async def _fetch_room_state(room_id: str) -> None:
|
async def _fetch_room_state(room_id: str) -> None:
|
||||||
room_state = await self.store.get_current_state_ids(room_id)
|
room_state = await self._state_storage_controller.get_current_state_ids(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
state_by_room[room_id] = room_state
|
state_by_room[room_id] = room_state
|
||||||
|
|
||||||
# Run at most 3 of these at once: sync does 10 at a time but email
|
# Run at most 3 of these at once: sync does 10 at a time but email
|
||||||
|
|
|
@ -418,6 +418,7 @@ class RoomStateRestServlet(RestServlet):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
|
|
||||||
|
@ -430,7 +431,7 @@ class RoomStateRestServlet(RestServlet):
|
||||||
if not ret:
|
if not ret:
|
||||||
raise NotFoundError("Room not found")
|
raise NotFoundError("Room not found")
|
||||||
|
|
||||||
event_ids = await self.store.get_current_state_ids(room_id)
|
event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
|
||||||
events = await self.store.get_events(event_ids.values())
|
events = await self.store.get_events(event_ids.values())
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
room_state = self._event_serializer.serialize_events(events.values(), now)
|
room_state = self._event_serializer.serialize_events(events.values(), now)
|
||||||
|
|
|
@ -77,7 +77,7 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||||
|
|
||||||
# Purge other caches based on room state.
|
# Purge other caches based on room state.
|
||||||
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
|
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
|
||||||
self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
|
self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,))
|
||||||
|
|
||||||
def _attempt_to_invalidate_cache(
|
def _attempt_to_invalidate_cache(
|
||||||
self, cache_name: str, key: Optional[Collection[Any]]
|
self, cache_name: str, key: Optional[Collection[Any]]
|
||||||
|
|
|
@ -18,7 +18,7 @@ from synapse.storage.controllers.persist_events import (
|
||||||
EventsPersistenceStorageController,
|
EventsPersistenceStorageController,
|
||||||
)
|
)
|
||||||
from synapse.storage.controllers.purge_events import PurgeEventsStorageController
|
from synapse.storage.controllers.purge_events import PurgeEventsStorageController
|
||||||
from synapse.storage.controllers.state import StateGroupStorageController
|
from synapse.storage.controllers.state import StateStorageController
|
||||||
from synapse.storage.databases import Databases
|
from synapse.storage.databases import Databases
|
||||||
from synapse.storage.databases.main import DataStore
|
from synapse.storage.databases.main import DataStore
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ class StorageControllers:
|
||||||
self.main = stores.main
|
self.main = stores.main
|
||||||
|
|
||||||
self.purge_events = PurgeEventsStorageController(hs, stores)
|
self.purge_events = PurgeEventsStorageController(hs, stores)
|
||||||
self.state = StateGroupStorageController(hs, stores)
|
self.state = StateStorageController(hs, stores)
|
||||||
|
|
||||||
self.persistence = None
|
self.persistence = None
|
||||||
if stores.persist_events:
|
if stores.persist_events:
|
||||||
|
|
|
@ -994,7 +994,7 @@ class EventsPersistenceStorageController:
|
||||||
|
|
||||||
Assumes that we are only persisting events for one room at a time.
|
Assumes that we are only persisting events for one room at a time.
|
||||||
"""
|
"""
|
||||||
existing_state = await self.main_store.get_current_state_ids(room_id)
|
existing_state = await self.main_store.get_partial_current_state_ids(room_id)
|
||||||
|
|
||||||
to_delete = [key for key in existing_state if key not in current_state]
|
to_delete = [key for key in existing_state if key not in current_state]
|
||||||
|
|
||||||
|
@ -1083,7 +1083,7 @@ class EventsPersistenceStorageController:
|
||||||
# The server will leave the room, so we go and find out which remote
|
# The server will leave the room, so we go and find out which remote
|
||||||
# users will still be joined when we leave.
|
# users will still be joined when we leave.
|
||||||
if current_state is None:
|
if current_state is None:
|
||||||
current_state = await self.main_store.get_current_state_ids(room_id)
|
current_state = await self.main_store.get_partial_current_state_ids(room_id)
|
||||||
current_state = dict(current_state)
|
current_state = dict(current_state)
|
||||||
for key in delta.to_delete:
|
for key in delta.to_delete:
|
||||||
current_state.pop(key, None)
|
current_state.pop(key, None)
|
||||||
|
|
|
@ -14,7 +14,9 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
|
Callable,
|
||||||
Collection,
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
|
@ -24,9 +26,13 @@ from typing import (
|
||||||
Tuple,
|
Tuple,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
|
from synapse.storage.util.partial_state_events_tracker import (
|
||||||
|
PartialCurrentStateTracker,
|
||||||
|
PartialStateEventsTracker,
|
||||||
|
)
|
||||||
from synapse.types import MutableStateMap, StateMap
|
from synapse.types import MutableStateMap, StateMap
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -36,17 +42,27 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class StateGroupStorageController:
|
class StateStorageController:
|
||||||
"""High level interface to fetching state for event."""
|
"""High level interface to fetching state for an event, or the current state
|
||||||
|
in a room.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer", stores: "Databases"):
|
def __init__(self, hs: "HomeServer", stores: "Databases"):
|
||||||
self._is_mine_id = hs.is_mine_id
|
self._is_mine_id = hs.is_mine_id
|
||||||
self.stores = stores
|
self.stores = stores
|
||||||
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
|
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
|
||||||
|
self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main)
|
||||||
|
|
||||||
def notify_event_un_partial_stated(self, event_id: str) -> None:
|
def notify_event_un_partial_stated(self, event_id: str) -> None:
|
||||||
self._partial_state_events_tracker.notify_un_partial_stated(event_id)
|
self._partial_state_events_tracker.notify_un_partial_stated(event_id)
|
||||||
|
|
||||||
|
def notify_room_un_partial_stated(self, room_id: str) -> None:
|
||||||
|
"""Notify that the room no longer has any partial state.
|
||||||
|
|
||||||
|
Must be called after `DataStore.clear_partial_state_room`
|
||||||
|
"""
|
||||||
|
self._partial_state_room_tracker.notify_un_partial_stated(room_id)
|
||||||
|
|
||||||
async def get_state_group_delta(
|
async def get_state_group_delta(
|
||||||
self, state_group: int
|
self, state_group: int
|
||||||
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
|
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
|
||||||
|
@ -349,3 +365,93 @@ class StateGroupStorageController:
|
||||||
return await self.stores.state.store_state_group(
|
return await self.stores.state.store_state_group(
|
||||||
event_id, room_id, prev_group, delta_ids, current_state_ids
|
event_id, room_id, prev_group, delta_ids, current_state_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_current_state_ids(
|
||||||
|
self,
|
||||||
|
room_id: str,
|
||||||
|
state_filter: Optional[StateFilter] = None,
|
||||||
|
on_invalidate: Optional[Callable[[], None]] = None,
|
||||||
|
) -> StateMap[str]:
|
||||||
|
"""Get the current state event ids for a room based on the
|
||||||
|
current_state_events table.
|
||||||
|
|
||||||
|
If a state filter is given (that is not `StateFilter.all()`) the query
|
||||||
|
result is *not* cached.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: The room to get the state IDs of. state_filter: The state
|
||||||
|
filter used to fetch state from the
|
||||||
|
database.
|
||||||
|
on_invalidate: Callback for when the `get_current_state_ids` cache
|
||||||
|
for the room gets invalidated.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The current state of the room.
|
||||||
|
"""
|
||||||
|
if not state_filter or state_filter.must_await_full_state(self._is_mine_id):
|
||||||
|
await self._partial_state_room_tracker.await_full_state(room_id)
|
||||||
|
|
||||||
|
if state_filter and not state_filter.is_full():
|
||||||
|
return await self.stores.main.get_partial_filtered_current_state_ids(
|
||||||
|
room_id, state_filter
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return await self.stores.main.get_partial_current_state_ids(
|
||||||
|
room_id, on_invalidate=on_invalidate
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
|
||||||
|
"""Get canonical alias for room, if any
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: The room ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The canonical alias, if any
|
||||||
|
"""
|
||||||
|
|
||||||
|
state = await self.get_current_state_ids(
|
||||||
|
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
|
||||||
|
)
|
||||||
|
|
||||||
|
event_id = state.get((EventTypes.CanonicalAlias, ""))
|
||||||
|
if not event_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
event = await self.stores.main.get_event(event_id, allow_none=True)
|
||||||
|
if not event:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return event.content.get("canonical_alias")
|
||||||
|
|
||||||
|
async def get_current_state_deltas(
|
||||||
|
self, prev_stream_id: int, max_stream_id: int
|
||||||
|
) -> Tuple[int, List[Dict[str, Any]]]:
|
||||||
|
"""Fetch a list of room state changes since the given stream id
|
||||||
|
|
||||||
|
Each entry in the result contains the following fields:
|
||||||
|
- stream_id (int)
|
||||||
|
- room_id (str)
|
||||||
|
- type (str): event type
|
||||||
|
- state_key (str):
|
||||||
|
- event_id (str|None): new event_id for this state key. None if the
|
||||||
|
state has been deleted.
|
||||||
|
- prev_event_id (str|None): previous event_id for this state key. None
|
||||||
|
if it's new state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prev_stream_id: point to get changes since (exclusive)
|
||||||
|
max_stream_id: the point that we know has been correctly persisted
|
||||||
|
- ie, an upper limit to return changes from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple consisting of:
|
||||||
|
- the stream id which these results go up to
|
||||||
|
- list of current_state_delta_stream rows. If it is empty, we are
|
||||||
|
up to date.
|
||||||
|
"""
|
||||||
|
# FIXME(faster_joins): what do we do here?
|
||||||
|
|
||||||
|
return await self.stores.main.get_partial_current_state_deltas(
|
||||||
|
prev_stream_id, max_stream_id
|
||||||
|
)
|
||||||
|
|
|
@ -1139,6 +1139,24 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||||
keyvalues={"room_id": room_id},
|
keyvalues={"room_id": room_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def is_partial_state_room(self, room_id: str) -> bool:
|
||||||
|
"""Checks if this room has partial state.
|
||||||
|
|
||||||
|
Returns true if this is a "partial-state" room, which means that the state
|
||||||
|
at events in the room, and `current_state_events`, may not yet be
|
||||||
|
complete.
|
||||||
|
"""
|
||||||
|
|
||||||
|
entry = await self.db_pool.simple_select_one_onecol(
|
||||||
|
table="partial_state_rooms",
|
||||||
|
keyvalues={"room_id": room_id},
|
||||||
|
retcol="room_id",
|
||||||
|
allow_none=True,
|
||||||
|
desc="is_partial_state_room",
|
||||||
|
)
|
||||||
|
|
||||||
|
return entry is not None
|
||||||
|
|
||||||
|
|
||||||
class _BackgroundUpdates:
|
class _BackgroundUpdates:
|
||||||
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
|
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
|
||||||
|
|
|
@ -242,7 +242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
Raises:
|
Raises:
|
||||||
NotFoundError if the room is unknown
|
NotFoundError if the room is unknown
|
||||||
"""
|
"""
|
||||||
state_ids = await self.get_current_state_ids(room_id)
|
state_ids = await self.get_partial_current_state_ids(room_id)
|
||||||
|
|
||||||
if not state_ids:
|
if not state_ids:
|
||||||
raise NotFoundError(f"Current state for room {room_id} is empty")
|
raise NotFoundError(f"Current state for room {room_id} is empty")
|
||||||
|
@ -258,10 +258,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
return create_event
|
return create_event
|
||||||
|
|
||||||
@cached(max_entries=100000, iterable=True)
|
@cached(max_entries=100000, iterable=True)
|
||||||
async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
|
async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]:
|
||||||
"""Get the current state event ids for a room based on the
|
"""Get the current state event ids for a room based on the
|
||||||
current_state_events table.
|
current_state_events table.
|
||||||
|
|
||||||
|
This may be the partial state if we're lazy joining the room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id: The room to get the state IDs of.
|
room_id: The room to get the state IDs of.
|
||||||
|
|
||||||
|
@ -280,17 +282,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
|
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_current_state_ids", _get_current_state_ids_txn
|
"get_partial_current_state_ids", _get_current_state_ids_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
# FIXME: how should this be cached?
|
# FIXME: how should this be cached?
|
||||||
async def get_filtered_current_state_ids(
|
async def get_partial_filtered_current_state_ids(
|
||||||
self, room_id: str, state_filter: Optional[StateFilter] = None
|
self, room_id: str, state_filter: Optional[StateFilter] = None
|
||||||
) -> StateMap[str]:
|
) -> StateMap[str]:
|
||||||
"""Get the current state event of a given type for a room based on the
|
"""Get the current state event of a given type for a room based on the
|
||||||
current_state_events table. This may not be as up-to-date as the result
|
current_state_events table. This may not be as up-to-date as the result
|
||||||
of doing a fresh state resolution as per state_handler.get_current_state
|
of doing a fresh state resolution as per state_handler.get_current_state
|
||||||
|
|
||||||
|
This may be the partial state if we're lazy joining the room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id
|
room_id
|
||||||
state_filter: The state filter used to fetch state
|
state_filter: The state filter used to fetch state
|
||||||
|
@ -306,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
|
|
||||||
if not where_clause:
|
if not where_clause:
|
||||||
# We delegate to the cached version
|
# We delegate to the cached version
|
||||||
return await self.get_current_state_ids(room_id)
|
return await self.get_partial_current_state_ids(room_id)
|
||||||
|
|
||||||
def _get_filtered_current_state_ids_txn(
|
def _get_filtered_current_state_ids_txn(
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
|
@ -334,30 +338,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
|
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
|
|
||||||
"""Get canonical alias for room, if any
|
|
||||||
|
|
||||||
Args:
|
|
||||||
room_id: The room ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The canonical alias, if any
|
|
||||||
"""
|
|
||||||
|
|
||||||
state = await self.get_filtered_current_state_ids(
|
|
||||||
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
|
|
||||||
)
|
|
||||||
|
|
||||||
event_id = state.get((EventTypes.CanonicalAlias, ""))
|
|
||||||
if not event_id:
|
|
||||||
return None
|
|
||||||
|
|
||||||
event = await self.get_event(event_id, allow_none=True)
|
|
||||||
if not event:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return event.content.get("canonical_alias")
|
|
||||||
|
|
||||||
@cached(max_entries=50000)
|
@cached(max_entries=50000)
|
||||||
async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
|
async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
|
||||||
return await self.db_pool.simple_select_one_onecol(
|
return await self.db_pool.simple_select_one_onecol(
|
||||||
|
|
|
@ -27,7 +27,7 @@ class StateDeltasStore(SQLBaseStore):
|
||||||
# attribute. TODO: can we get static analysis to enforce this?
|
# attribute. TODO: can we get static analysis to enforce this?
|
||||||
_curr_state_delta_stream_cache: StreamChangeCache
|
_curr_state_delta_stream_cache: StreamChangeCache
|
||||||
|
|
||||||
async def get_current_state_deltas(
|
async def get_partial_current_state_deltas(
|
||||||
self, prev_stream_id: int, max_stream_id: int
|
self, prev_stream_id: int, max_stream_id: int
|
||||||
) -> Tuple[int, List[Dict[str, Any]]]:
|
) -> Tuple[int, List[Dict[str, Any]]]:
|
||||||
"""Fetch a list of room state changes since the given stream id
|
"""Fetch a list of room state changes since the given stream id
|
||||||
|
@ -42,6 +42,8 @@ class StateDeltasStore(SQLBaseStore):
|
||||||
- prev_event_id (str|None): previous event_id for this state key. None
|
- prev_event_id (str|None): previous event_id for this state key. None
|
||||||
if it's new state.
|
if it's new state.
|
||||||
|
|
||||||
|
This may be the partial state if we're lazy joining the room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prev_stream_id: point to get changes since (exclusive)
|
prev_stream_id: point to get changes since (exclusive)
|
||||||
max_stream_id: the point that we know has been correctly persisted
|
max_stream_id: the point that we know has been correctly persisted
|
||||||
|
|
|
@ -441,7 +441,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||||
(EventTypes.RoomHistoryVisibility, ""),
|
(EventTypes.RoomHistoryVisibility, ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined]
|
# Getting the partial state is fine, as we're not looking at membership
|
||||||
|
# events.
|
||||||
|
current_state_ids = await self.get_partial_filtered_current_state_ids( # type: ignore[attr-defined]
|
||||||
room_id, StateFilter.from_types(types_to_filter)
|
room_id, StateFilter.from_types(types_to_filter)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ from twisted.internet.defer import Deferred
|
||||||
|
|
||||||
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
||||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
|
from synapse.storage.databases.main.room import RoomWorkerStore
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -118,3 +119,62 @@ class PartialStateEventsTracker:
|
||||||
observer_set.discard(observer)
|
observer_set.discard(observer)
|
||||||
if not observer_set:
|
if not observer_set:
|
||||||
del self._observers[event_id]
|
del self._observers[event_id]
|
||||||
|
|
||||||
|
|
||||||
|
class PartialCurrentStateTracker:
|
||||||
|
"""Keeps track of which rooms have partial state, after partial-state joins"""
|
||||||
|
|
||||||
|
def __init__(self, store: RoomWorkerStore):
|
||||||
|
self._store = store
|
||||||
|
|
||||||
|
# a map from room id to a set of Deferreds which are waiting for that room to be
|
||||||
|
# un-partial-stated.
|
||||||
|
self._observers: Dict[str, Set[Deferred[None]]] = defaultdict(set)
|
||||||
|
|
||||||
|
def notify_un_partial_stated(self, room_id: str) -> None:
|
||||||
|
"""Notify that we now have full current state for a given room
|
||||||
|
|
||||||
|
Unblocks any callers to await_full_state() for that room.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: the room that now has full current state.
|
||||||
|
"""
|
||||||
|
observers = self._observers.pop(room_id, None)
|
||||||
|
if not observers:
|
||||||
|
return
|
||||||
|
logger.info(
|
||||||
|
"Notifying %i things waiting for un-partial-stating of room %s",
|
||||||
|
len(observers),
|
||||||
|
room_id,
|
||||||
|
)
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
for o in observers:
|
||||||
|
o.callback(None)
|
||||||
|
|
||||||
|
async def await_full_state(self, room_id: str) -> None:
|
||||||
|
# We add the deferred immediately so that the DB call to check for
|
||||||
|
# partial state doesn't race when we unpartial the room.
|
||||||
|
d: Deferred[None] = Deferred()
|
||||||
|
self._observers.setdefault(room_id, set()).add(d)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if the room has partial current state or not.
|
||||||
|
has_partial_state = await self._store.is_partial_state_room(room_id)
|
||||||
|
if not has_partial_state:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Awaiting un-partial-stating of room %s",
|
||||||
|
room_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await make_deferred_yieldable(d)
|
||||||
|
|
||||||
|
logger.info("Room has un-partial-stated")
|
||||||
|
finally:
|
||||||
|
# Remove the added observer, and remove the room entry if its empty.
|
||||||
|
ds = self._observers.get(room_id)
|
||||||
|
if ds is not None:
|
||||||
|
ds.discard(d)
|
||||||
|
if not ds:
|
||||||
|
self._observers.pop(room_id, None)
|
||||||
|
|
|
@ -237,7 +237,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
)
|
)
|
||||||
current_state = self.get_success(
|
current_state = self.get_success(
|
||||||
self.store.get_events_as_list(
|
self.store.get_events_as_list(
|
||||||
(self.get_success(self.store.get_current_state_ids(room_id))).values()
|
(
|
||||||
|
self.get_success(self.store.get_partial_current_state_ids(room_id))
|
||||||
|
).values()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -512,7 +514,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
self.get_success(d)
|
self.get_success(d)
|
||||||
|
|
||||||
# sanity-check: the room should show that the new user is a member
|
# sanity-check: the room should show that the new user is a member
|
||||||
r = self.get_success(self.store.get_current_state_ids(room_id))
|
r = self.get_success(self.store.get_partial_current_state_ids(room_id))
|
||||||
self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
|
self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
|
||||||
|
|
||||||
return join_event
|
return join_event
|
||||||
|
|
|
@ -91,7 +91,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||||
event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
|
event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
|
||||||
)
|
)
|
||||||
|
|
||||||
initial_state_map = self.get_success(main_store.get_current_state_ids(room_id))
|
initial_state_map = self.get_success(
|
||||||
|
main_store.get_partial_current_state_ids(room_id)
|
||||||
|
)
|
||||||
|
|
||||||
auth_event_ids = [
|
auth_event_ids = [
|
||||||
initial_state_map[("m.room.create", "")],
|
initial_state_map[("m.room.create", "")],
|
||||||
|
|
|
@ -146,7 +146,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.datastore.get_current_state_deltas = Mock(return_value=(0, None))
|
self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None))
|
||||||
|
|
||||||
self.datastore.get_to_device_stream_token = lambda: 0
|
self.datastore.get_to_device_stream_token = lambda: 0
|
||||||
self.datastore.get_new_device_msgs_for_remote = (
|
self.datastore.get_new_device_msgs_for_remote = (
|
||||||
|
|
|
@ -249,7 +249,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
new_space_id = channel.json_body["replacement_room"]
|
new_space_id = channel.json_body["replacement_room"]
|
||||||
|
|
||||||
state_ids = self.get_success(self.store.get_current_state_ids(new_space_id))
|
state_ids = self.get_success(
|
||||||
|
self.store.get_partial_current_state_ids(new_space_id)
|
||||||
|
)
|
||||||
|
|
||||||
# Ensure the new room is still a space.
|
# Ensure the new room is still a space.
|
||||||
create_event = self.get_success(
|
create_event = self.get_success(
|
||||||
|
@ -284,7 +286,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
new_room_id = channel.json_body["replacement_room"]
|
new_room_id = channel.json_body["replacement_room"]
|
||||||
|
|
||||||
state_ids = self.get_success(self.store.get_current_state_ids(new_room_id))
|
state_ids = self.get_success(
|
||||||
|
self.store.get_partial_current_state_ids(new_room_id)
|
||||||
|
)
|
||||||
|
|
||||||
# Ensure the new room is the same type as the old room.
|
# Ensure the new room is the same type as the old room.
|
||||||
create_event = self.get_success(
|
create_event = self.get_success(
|
||||||
|
|
|
@ -17,8 +17,12 @@ from unittest import mock
|
||||||
|
|
||||||
from twisted.internet.defer import CancelledError, ensureDeferred
|
from twisted.internet.defer import CancelledError, ensureDeferred
|
||||||
|
|
||||||
from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
|
from synapse.storage.util.partial_state_events_tracker import (
|
||||||
|
PartialCurrentStateTracker,
|
||||||
|
PartialStateEventsTracker,
|
||||||
|
)
|
||||||
|
|
||||||
|
from tests.test_utils import make_awaitable
|
||||||
from tests.unittest import TestCase
|
from tests.unittest import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
@ -115,3 +119,56 @@ class PartialStateEventsTrackerTestCase(TestCase):
|
||||||
|
|
||||||
self.tracker.notify_un_partial_stated("event1")
|
self.tracker.notify_un_partial_stated("event1")
|
||||||
self.successResultOf(d2)
|
self.successResultOf(d2)
|
||||||
|
|
||||||
|
|
||||||
|
class PartialCurrentStateTrackerTestCase(TestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.mock_store = mock.Mock(spec_set=["is_partial_state_room"])
|
||||||
|
|
||||||
|
self.tracker = PartialCurrentStateTracker(self.mock_store)
|
||||||
|
|
||||||
|
def test_does_not_block_for_full_state_rooms(self):
|
||||||
|
self.mock_store.is_partial_state_room.return_value = make_awaitable(False)
|
||||||
|
|
||||||
|
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
|
||||||
|
|
||||||
|
def test_blocks_for_partial_room_state(self):
|
||||||
|
self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
|
||||||
|
|
||||||
|
d = ensureDeferred(self.tracker.await_full_state("room_id"))
|
||||||
|
|
||||||
|
# there should be no result yet
|
||||||
|
self.assertNoResult(d)
|
||||||
|
|
||||||
|
# notifying that the room has been de-partial-stated should unblock
|
||||||
|
self.tracker.notify_un_partial_stated("room_id")
|
||||||
|
self.successResultOf(d)
|
||||||
|
|
||||||
|
def test_un_partial_state_race(self):
|
||||||
|
# We should correctly handle race between awaiting the state and us
|
||||||
|
# un-partialling the state
|
||||||
|
async def is_partial_state_room(events):
|
||||||
|
self.tracker.notify_un_partial_stated("room_id")
|
||||||
|
return True
|
||||||
|
|
||||||
|
self.mock_store.is_partial_state_room.side_effect = is_partial_state_room
|
||||||
|
|
||||||
|
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
|
||||||
|
|
||||||
|
def test_cancellation(self):
|
||||||
|
self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
|
||||||
|
|
||||||
|
d1 = ensureDeferred(self.tracker.await_full_state("room_id"))
|
||||||
|
self.assertNoResult(d1)
|
||||||
|
|
||||||
|
d2 = ensureDeferred(self.tracker.await_full_state("room_id"))
|
||||||
|
self.assertNoResult(d2)
|
||||||
|
|
||||||
|
d1.cancel()
|
||||||
|
self.assertFailure(d1, CancelledError)
|
||||||
|
|
||||||
|
# d2 should still be waiting!
|
||||||
|
self.assertNoResult(d2)
|
||||||
|
|
||||||
|
self.tracker.notify_un_partial_stated("room_id")
|
||||||
|
self.successResultOf(d2)
|
||||||
|
|
Loading…
Reference in a new issue