mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-29 15:39:00 +03:00
Only store data in caches, not "smart" objects (#9845)
This commit is contained in:
parent
51a20914a8
commit
3853a7edfc
3 changed files with 180 additions and 139 deletions
1
changelog.d/9845.misc
Normal file
1
changelog.d/9845.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Only store the raw data in the in-memory caches, rather than objects that include references to e.g. the data stores.
|
|
@ -106,6 +106,10 @@ class BulkPushRuleEvaluator:
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
|
# Used by `RulesForRoom` to ensure only one thing mutates the cache at a
|
||||||
|
# time. Keyed off room_id.
|
||||||
|
self._rules_linearizer = Linearizer(name="rules_for_room")
|
||||||
|
|
||||||
self.room_push_rule_cache_metrics = register_cache(
|
self.room_push_rule_cache_metrics = register_cache(
|
||||||
"cache",
|
"cache",
|
||||||
"room_push_rule_cache",
|
"room_push_rule_cache",
|
||||||
|
@ -123,7 +127,16 @@ class BulkPushRuleEvaluator:
|
||||||
dict of user_id -> push_rules
|
dict of user_id -> push_rules
|
||||||
"""
|
"""
|
||||||
room_id = event.room_id
|
room_id = event.room_id
|
||||||
rules_for_room = self._get_rules_for_room(room_id)
|
|
||||||
|
rules_for_room_data = self._get_rules_for_room(room_id)
|
||||||
|
rules_for_room = RulesForRoom(
|
||||||
|
hs=self.hs,
|
||||||
|
room_id=room_id,
|
||||||
|
rules_for_room_cache=self._get_rules_for_room.cache,
|
||||||
|
room_push_rule_cache_metrics=self.room_push_rule_cache_metrics,
|
||||||
|
linearizer=self._rules_linearizer,
|
||||||
|
cached_data=rules_for_room_data,
|
||||||
|
)
|
||||||
|
|
||||||
rules_by_user = await rules_for_room.get_rules(event, context)
|
rules_by_user = await rules_for_room.get_rules(event, context)
|
||||||
|
|
||||||
|
@ -142,17 +155,12 @@ class BulkPushRuleEvaluator:
|
||||||
return rules_by_user
|
return rules_by_user
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
|
def _get_rules_for_room(self, room_id: str) -> "RulesForRoomData":
|
||||||
"""Get the current RulesForRoom object for the given room id"""
|
"""Get the current RulesForRoomData object for the given room id"""
|
||||||
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache
|
# It's important that the RulesForRoomData object gets added to self._get_rules_for_room.cache
|
||||||
# before any lookup methods get called on it as otherwise there may be
|
# before any lookup methods get called on it as otherwise there may be
|
||||||
# a race if invalidate_all gets called (which assumes its in the cache)
|
# a race if invalidate_all gets called (which assumes its in the cache)
|
||||||
return RulesForRoom(
|
return RulesForRoomData()
|
||||||
self.hs,
|
|
||||||
room_id,
|
|
||||||
self._get_rules_for_room.cache,
|
|
||||||
self.room_push_rule_cache_metrics,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _get_power_levels_and_sender_level(
|
async def _get_power_levels_and_sender_level(
|
||||||
self, event: EventBase, context: EventContext
|
self, event: EventBase, context: EventContext
|
||||||
|
@ -282,11 +290,49 @@ def _condition_checker(
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class RulesForRoomData:
|
||||||
|
"""The data stored in the cache by `RulesForRoom`.
|
||||||
|
|
||||||
|
We don't store `RulesForRoom` directly in the cache as we want our caches to
|
||||||
|
*only* include data, and not references to e.g. the data stores.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# event_id -> (user_id, state)
|
||||||
|
member_map = attr.ib(type=Dict[str, Tuple[str, str]], factory=dict)
|
||||||
|
# user_id -> rules
|
||||||
|
rules_by_user = attr.ib(type=Dict[str, List[Dict[str, dict]]], factory=dict)
|
||||||
|
|
||||||
|
# The last state group we updated the caches for. If the state_group of
|
||||||
|
# a new event comes along, we know that we can just return the cached
|
||||||
|
# result.
|
||||||
|
# On invalidation of the rules themselves (if the user changes them),
|
||||||
|
# we invalidate everything and set state_group to `object()`
|
||||||
|
state_group = attr.ib(type=Union[object, int], factory=object)
|
||||||
|
|
||||||
|
# A sequence number to keep track of when we're allowed to update the
|
||||||
|
# cache. We bump the sequence number when we invalidate the cache. If
|
||||||
|
# the sequence number changes while we're calculating stuff we should
|
||||||
|
# not update the cache with it.
|
||||||
|
sequence = attr.ib(type=int, default=0)
|
||||||
|
|
||||||
|
# A cache of user_ids that we *know* aren't interesting, e.g. user_ids
|
||||||
|
# owned by AS's, or remote users, etc. (I.e. users we will never need to
|
||||||
|
# calculate push for)
|
||||||
|
# These never need to be invalidated as we will never set up push for
|
||||||
|
# them.
|
||||||
|
uninteresting_user_set = attr.ib(type=Set[str], factory=set)
|
||||||
|
|
||||||
|
|
||||||
class RulesForRoom:
|
class RulesForRoom:
|
||||||
"""Caches push rules for users in a room.
|
"""Caches push rules for users in a room.
|
||||||
|
|
||||||
This efficiently handles users joining/leaving the room by not invalidating
|
This efficiently handles users joining/leaving the room by not invalidating
|
||||||
the entire cache for the room.
|
the entire cache for the room.
|
||||||
|
|
||||||
|
A new instance is constructed for each call to
|
||||||
|
`BulkPushRuleEvaluator._get_rules_for_event`, with the cached data from
|
||||||
|
previous calls passed in.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -295,6 +341,8 @@ class RulesForRoom:
|
||||||
room_id: str,
|
room_id: str,
|
||||||
rules_for_room_cache: LruCache,
|
rules_for_room_cache: LruCache,
|
||||||
room_push_rule_cache_metrics: CacheMetric,
|
room_push_rule_cache_metrics: CacheMetric,
|
||||||
|
linearizer: Linearizer,
|
||||||
|
cached_data: RulesForRoomData,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -303,38 +351,21 @@ class RulesForRoom:
|
||||||
rules_for_room_cache: The cache object that caches these
|
rules_for_room_cache: The cache object that caches these
|
||||||
RoomsForUser objects.
|
RoomsForUser objects.
|
||||||
room_push_rule_cache_metrics: The metrics object
|
room_push_rule_cache_metrics: The metrics object
|
||||||
|
linearizer: The linearizer used to ensure only one thing mutates
|
||||||
|
the cache at a time. Keyed off room_id
|
||||||
|
cached_data: Cached data from previous calls to `self.get_rules`,
|
||||||
|
can be mutated.
|
||||||
"""
|
"""
|
||||||
self.room_id = room_id
|
self.room_id = room_id
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.room_push_rule_cache_metrics = room_push_rule_cache_metrics
|
self.room_push_rule_cache_metrics = room_push_rule_cache_metrics
|
||||||
|
|
||||||
self.linearizer = Linearizer(name="rules_for_room")
|
# Used to ensure only one thing mutates the cache at a time. Keyed off
|
||||||
|
# room_id.
|
||||||
|
self.linearizer = linearizer
|
||||||
|
|
||||||
# event_id -> (user_id, state)
|
self.data = cached_data
|
||||||
self.member_map = {} # type: Dict[str, Tuple[str, str]]
|
|
||||||
# user_id -> rules
|
|
||||||
self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]]
|
|
||||||
|
|
||||||
# The last state group we updated the caches for. If the state_group of
|
|
||||||
# a new event comes along, we know that we can just return the cached
|
|
||||||
# result.
|
|
||||||
# On invalidation of the rules themselves (if the user changes them),
|
|
||||||
# we invalidate everything and set state_group to `object()`
|
|
||||||
self.state_group = object()
|
|
||||||
|
|
||||||
# A sequence number to keep track of when we're allowed to update the
|
|
||||||
# cache. We bump the sequence number when we invalidate the cache. If
|
|
||||||
# the sequence number changes while we're calculating stuff we should
|
|
||||||
# not update the cache with it.
|
|
||||||
self.sequence = 0
|
|
||||||
|
|
||||||
# A cache of user_ids that we *know* aren't interesting, e.g. user_ids
|
|
||||||
# owned by AS's, or remote users, etc. (I.e. users we will never need to
|
|
||||||
# calculate push for)
|
|
||||||
# These never need to be invalidated as we will never set up push for
|
|
||||||
# them.
|
|
||||||
self.uninteresting_user_set = set() # type: Set[str]
|
|
||||||
|
|
||||||
# We need to be clever on the invalidating caches callbacks, as
|
# We need to be clever on the invalidating caches callbacks, as
|
||||||
# otherwise the invalidation callback holds a reference to the object,
|
# otherwise the invalidation callback holds a reference to the object,
|
||||||
|
@ -352,25 +383,25 @@ class RulesForRoom:
|
||||||
"""
|
"""
|
||||||
state_group = context.state_group
|
state_group = context.state_group
|
||||||
|
|
||||||
if state_group and self.state_group == state_group:
|
if state_group and self.data.state_group == state_group:
|
||||||
logger.debug("Using cached rules for %r", self.room_id)
|
logger.debug("Using cached rules for %r", self.room_id)
|
||||||
self.room_push_rule_cache_metrics.inc_hits()
|
self.room_push_rule_cache_metrics.inc_hits()
|
||||||
return self.rules_by_user
|
return self.data.rules_by_user
|
||||||
|
|
||||||
with (await self.linearizer.queue(())):
|
with (await self.linearizer.queue(self.room_id)):
|
||||||
if state_group and self.state_group == state_group:
|
if state_group and self.data.state_group == state_group:
|
||||||
logger.debug("Using cached rules for %r", self.room_id)
|
logger.debug("Using cached rules for %r", self.room_id)
|
||||||
self.room_push_rule_cache_metrics.inc_hits()
|
self.room_push_rule_cache_metrics.inc_hits()
|
||||||
return self.rules_by_user
|
return self.data.rules_by_user
|
||||||
|
|
||||||
self.room_push_rule_cache_metrics.inc_misses()
|
self.room_push_rule_cache_metrics.inc_misses()
|
||||||
|
|
||||||
ret_rules_by_user = {}
|
ret_rules_by_user = {}
|
||||||
missing_member_event_ids = {}
|
missing_member_event_ids = {}
|
||||||
if state_group and self.state_group == context.prev_group:
|
if state_group and self.data.state_group == context.prev_group:
|
||||||
# If we have a simple delta then we can reuse most of the previous
|
# If we have a simple delta then we can reuse most of the previous
|
||||||
# results.
|
# results.
|
||||||
ret_rules_by_user = self.rules_by_user
|
ret_rules_by_user = self.data.rules_by_user
|
||||||
current_state_ids = context.delta_ids
|
current_state_ids = context.delta_ids
|
||||||
|
|
||||||
push_rules_delta_state_cache_metric.inc_hits()
|
push_rules_delta_state_cache_metric.inc_hits()
|
||||||
|
@ -393,24 +424,24 @@ class RulesForRoom:
|
||||||
if typ != EventTypes.Member:
|
if typ != EventTypes.Member:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if user_id in self.uninteresting_user_set:
|
if user_id in self.data.uninteresting_user_set:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not self.is_mine_id(user_id):
|
if not self.is_mine_id(user_id):
|
||||||
self.uninteresting_user_set.add(user_id)
|
self.data.uninteresting_user_set.add(user_id)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if self.store.get_if_app_services_interested_in_user(user_id):
|
if self.store.get_if_app_services_interested_in_user(user_id):
|
||||||
self.uninteresting_user_set.add(user_id)
|
self.data.uninteresting_user_set.add(user_id)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
event_id = current_state_ids[key]
|
event_id = current_state_ids[key]
|
||||||
|
|
||||||
res = self.member_map.get(event_id, None)
|
res = self.data.member_map.get(event_id, None)
|
||||||
if res:
|
if res:
|
||||||
user_id, state = res
|
user_id, state = res
|
||||||
if state == Membership.JOIN:
|
if state == Membership.JOIN:
|
||||||
rules = self.rules_by_user.get(user_id, None)
|
rules = self.data.rules_by_user.get(user_id, None)
|
||||||
if rules:
|
if rules:
|
||||||
ret_rules_by_user[user_id] = rules
|
ret_rules_by_user[user_id] = rules
|
||||||
continue
|
continue
|
||||||
|
@ -430,7 +461,7 @@ class RulesForRoom:
|
||||||
else:
|
else:
|
||||||
# The push rules didn't change but lets update the cache anyway
|
# The push rules didn't change but lets update the cache anyway
|
||||||
self.update_cache(
|
self.update_cache(
|
||||||
self.sequence,
|
self.data.sequence,
|
||||||
members={}, # There were no membership changes
|
members={}, # There were no membership changes
|
||||||
rules_by_user=ret_rules_by_user,
|
rules_by_user=ret_rules_by_user,
|
||||||
state_group=state_group,
|
state_group=state_group,
|
||||||
|
@ -461,7 +492,7 @@ class RulesForRoom:
|
||||||
for. Used when updating the cache.
|
for. Used when updating the cache.
|
||||||
event: The event we are currently computing push rules for.
|
event: The event we are currently computing push rules for.
|
||||||
"""
|
"""
|
||||||
sequence = self.sequence
|
sequence = self.data.sequence
|
||||||
|
|
||||||
rows = await self.store.get_membership_from_event_ids(member_event_ids.values())
|
rows = await self.store.get_membership_from_event_ids(member_event_ids.values())
|
||||||
|
|
||||||
|
@ -501,23 +532,11 @@ class RulesForRoom:
|
||||||
|
|
||||||
self.update_cache(sequence, members, ret_rules_by_user, state_group)
|
self.update_cache(sequence, members, ret_rules_by_user, state_group)
|
||||||
|
|
||||||
def invalidate_all(self) -> None:
|
|
||||||
# Note: Don't hand this function directly to an invalidation callback
|
|
||||||
# as it keeps a reference to self and will stop this instance from being
|
|
||||||
# GC'd if it gets dropped from the rules_to_user cache. Instead use
|
|
||||||
# `self.invalidate_all_cb`
|
|
||||||
logger.debug("Invalidating RulesForRoom for %r", self.room_id)
|
|
||||||
self.sequence += 1
|
|
||||||
self.state_group = object()
|
|
||||||
self.member_map = {}
|
|
||||||
self.rules_by_user = {}
|
|
||||||
push_rules_invalidation_counter.inc()
|
|
||||||
|
|
||||||
def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
|
def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
|
||||||
if sequence == self.sequence:
|
if sequence == self.data.sequence:
|
||||||
self.member_map.update(members)
|
self.data.member_map.update(members)
|
||||||
self.rules_by_user = rules_by_user
|
self.data.rules_by_user = rules_by_user
|
||||||
self.state_group = state_group
|
self.data.state_group = state_group
|
||||||
|
|
||||||
|
|
||||||
@attr.attrs(slots=True, frozen=True)
|
@attr.attrs(slots=True, frozen=True)
|
||||||
|
@ -535,6 +554,10 @@ class _Invalidation:
|
||||||
room_id = attr.ib(type=str)
|
room_id = attr.ib(type=str)
|
||||||
|
|
||||||
def __call__(self) -> None:
|
def __call__(self) -> None:
|
||||||
rules = self.cache.get(self.room_id, None, update_metrics=False)
|
rules_data = self.cache.get(self.room_id, None, update_metrics=False)
|
||||||
if rules:
|
if rules_data:
|
||||||
rules.invalidate_all()
|
rules_data.sequence += 1
|
||||||
|
rules_data.state_group = object()
|
||||||
|
rules_data.member_map = {}
|
||||||
|
rules_data.rules_by_user = {}
|
||||||
|
push_rules_invalidation_counter.inc()
|
||||||
|
|
|
@ -23,8 +23,11 @@ from typing import (
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
|
@ -43,7 +46,7 @@ from synapse.storage.roommember import (
|
||||||
ProfileInfo,
|
ProfileInfo,
|
||||||
RoomsForUser,
|
RoomsForUser,
|
||||||
)
|
)
|
||||||
from synapse.types import PersistedEventPosition, get_domain_from_id
|
from synapse.types import PersistedEventPosition, StateMap, get_domain_from_id
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches import intern_string
|
from synapse.util.caches import intern_string
|
||||||
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
|
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
|
||||||
|
@ -63,6 +66,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
|
# Used by `_get_joined_hosts` to ensure only one thing mutates the cache
|
||||||
|
# at a time. Keyed by room_id.
|
||||||
|
self._joined_host_linearizer = Linearizer("_JoinedHostsCache")
|
||||||
|
|
||||||
# Is the current_state_events.membership up to date? Or is the
|
# Is the current_state_events.membership up to date? Or is the
|
||||||
# background update still running?
|
# background update still running?
|
||||||
self._current_state_events_membership_up_to_date = False
|
self._current_state_events_membership_up_to_date = False
|
||||||
|
@ -740,19 +747,82 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
|
|
||||||
@cached(num_args=2, max_entries=10000, iterable=True)
|
@cached(num_args=2, max_entries=10000, iterable=True)
|
||||||
async def _get_joined_hosts(
|
async def _get_joined_hosts(
|
||||||
self, room_id, state_group, current_state_ids, state_entry
|
self,
|
||||||
):
|
room_id: str,
|
||||||
# We don't use `state_group`, its there so that we can cache based
|
state_group: int,
|
||||||
# on it. However, its important that its never None, since two current_state's
|
current_state_ids: StateMap[str],
|
||||||
# with a state_group of None are likely to be different.
|
state_entry: "_StateCacheEntry",
|
||||||
|
) -> FrozenSet[str]:
|
||||||
|
# We don't use `state_group`, its there so that we can cache based on
|
||||||
|
# it. However, its important that its never None, since two
|
||||||
|
# current_state's with a state_group of None are likely to be different.
|
||||||
|
#
|
||||||
|
# The `state_group` must match the `state_entry.state_group` (if not None).
|
||||||
assert state_group is not None
|
assert state_group is not None
|
||||||
|
assert state_entry.state_group is None or state_entry.state_group == state_group
|
||||||
|
|
||||||
|
# We use a secondary cache of previous work to allow us to build up the
|
||||||
|
# joined hosts for the given state group based on previous state groups.
|
||||||
|
#
|
||||||
|
# We cache one object per room containing the results of the last state
|
||||||
|
# group we got joined hosts for. The idea is that generally
|
||||||
|
# `get_joined_hosts` is called with the "current" state group for the
|
||||||
|
# room, and so consecutive calls will be for consecutive state groups
|
||||||
|
# which point to the previous state group.
|
||||||
cache = await self._get_joined_hosts_cache(room_id)
|
cache = await self._get_joined_hosts_cache(room_id)
|
||||||
return await cache.get_destinations(state_entry)
|
|
||||||
|
# If the state group in the cache matches, we already have the data we need.
|
||||||
|
if state_entry.state_group == cache.state_group:
|
||||||
|
return frozenset(cache.hosts_to_joined_users)
|
||||||
|
|
||||||
|
# Since we'll mutate the cache we need to lock.
|
||||||
|
with (await self._joined_host_linearizer.queue(room_id)):
|
||||||
|
if state_entry.state_group == cache.state_group:
|
||||||
|
# Same state group, so nothing to do. We've already checked for
|
||||||
|
# this above, but the cache may have changed while waiting on
|
||||||
|
# the lock.
|
||||||
|
pass
|
||||||
|
elif state_entry.prev_group == cache.state_group:
|
||||||
|
# The cached work is for the previous state group, so we work out
|
||||||
|
# the delta.
|
||||||
|
for (typ, state_key), event_id in state_entry.delta_ids.items():
|
||||||
|
if typ != EventTypes.Member:
|
||||||
|
continue
|
||||||
|
|
||||||
|
host = intern_string(get_domain_from_id(state_key))
|
||||||
|
user_id = state_key
|
||||||
|
known_joins = cache.hosts_to_joined_users.setdefault(host, set())
|
||||||
|
|
||||||
|
event = await self.get_event(event_id)
|
||||||
|
if event.membership == Membership.JOIN:
|
||||||
|
known_joins.add(user_id)
|
||||||
|
else:
|
||||||
|
known_joins.discard(user_id)
|
||||||
|
|
||||||
|
if not known_joins:
|
||||||
|
cache.hosts_to_joined_users.pop(host, None)
|
||||||
|
else:
|
||||||
|
# The cache doesn't match the state group or prev state group,
|
||||||
|
# so we calculate the result from first principles.
|
||||||
|
joined_users = await self.get_joined_users_from_state(
|
||||||
|
room_id, state_entry
|
||||||
|
)
|
||||||
|
|
||||||
|
cache.hosts_to_joined_users = {}
|
||||||
|
for user_id in joined_users:
|
||||||
|
host = intern_string(get_domain_from_id(user_id))
|
||||||
|
cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
|
||||||
|
|
||||||
|
if state_entry.state_group:
|
||||||
|
cache.state_group = state_entry.state_group
|
||||||
|
else:
|
||||||
|
cache.state_group = object()
|
||||||
|
|
||||||
|
return frozenset(cache.hosts_to_joined_users)
|
||||||
|
|
||||||
@cached(max_entries=10000)
|
@cached(max_entries=10000)
|
||||||
def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
|
def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
|
||||||
return _JoinedHostsCache(self, room_id)
|
return _JoinedHostsCache()
|
||||||
|
|
||||||
@cached(num_args=2)
|
@cached(num_args=2)
|
||||||
async def did_forget(self, user_id: str, room_id: str) -> bool:
|
async def did_forget(self, user_id: str, room_id: str) -> bool:
|
||||||
|
@ -1062,71 +1132,18 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
||||||
await self.db_pool.runInteraction("forget_membership", f)
|
await self.db_pool.runInteraction("forget_membership", f)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
class _JoinedHostsCache:
|
class _JoinedHostsCache:
|
||||||
"""Cache for joined hosts in a room that is optimised to handle updates
|
"""The cached data used by the `_get_joined_hosts_cache`."""
|
||||||
via state deltas.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, store, room_id):
|
# Dict of host to the set of their users in the room at the state group.
|
||||||
self.store = store
|
hosts_to_joined_users = attr.ib(type=Dict[str, Set[str]], factory=dict)
|
||||||
self.room_id = room_id
|
|
||||||
|
|
||||||
self.hosts_to_joined_users = {}
|
# The state group `hosts_to_joined_users` is derived from. Will be an object
|
||||||
|
# if the instance is newly created or if the state is not based on a state
|
||||||
self.state_group = object()
|
# group. (An object is used as a sentinel value to ensure that it never is
|
||||||
|
# equal to anything else).
|
||||||
self.linearizer = Linearizer("_JoinedHostsCache")
|
state_group = attr.ib(type=Union[object, int], factory=object)
|
||||||
|
|
||||||
self._len = 0
|
|
||||||
|
|
||||||
async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]:
|
|
||||||
"""Get set of destinations for a state entry
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state_entry
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The destinations as a set.
|
|
||||||
"""
|
|
||||||
if state_entry.state_group == self.state_group:
|
|
||||||
return frozenset(self.hosts_to_joined_users)
|
|
||||||
|
|
||||||
with (await self.linearizer.queue(())):
|
|
||||||
if state_entry.state_group == self.state_group:
|
|
||||||
pass
|
|
||||||
elif state_entry.prev_group == self.state_group:
|
|
||||||
for (typ, state_key), event_id in state_entry.delta_ids.items():
|
|
||||||
if typ != EventTypes.Member:
|
|
||||||
continue
|
|
||||||
|
|
||||||
host = intern_string(get_domain_from_id(state_key))
|
|
||||||
user_id = state_key
|
|
||||||
known_joins = self.hosts_to_joined_users.setdefault(host, set())
|
|
||||||
|
|
||||||
event = await self.store.get_event(event_id)
|
|
||||||
if event.membership == Membership.JOIN:
|
|
||||||
known_joins.add(user_id)
|
|
||||||
else:
|
|
||||||
known_joins.discard(user_id)
|
|
||||||
|
|
||||||
if not known_joins:
|
|
||||||
self.hosts_to_joined_users.pop(host, None)
|
|
||||||
else:
|
|
||||||
joined_users = await self.store.get_joined_users_from_state(
|
|
||||||
self.room_id, state_entry
|
|
||||||
)
|
|
||||||
|
|
||||||
self.hosts_to_joined_users = {}
|
|
||||||
for user_id in joined_users:
|
|
||||||
host = intern_string(get_domain_from_id(user_id))
|
|
||||||
self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
|
|
||||||
|
|
||||||
if state_entry.state_group:
|
|
||||||
self.state_group = state_entry.state_group
|
|
||||||
else:
|
|
||||||
self.state_group = object()
|
|
||||||
self._len = sum(len(v) for v in self.hosts_to_joined_users.values())
|
|
||||||
return frozenset(self.hosts_to_joined_users)
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self._len
|
return sum(len(v) for v in self.hosts_to_joined_users.values())
|
||||||
|
|
Loading…
Reference in a new issue