mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-28 23:20:09 +03:00
Convert the roommember database to async/await. (#8070)
This commit is contained in:
parent
5ecc8b5825
commit
fbe930dad2
5 changed files with 116 additions and 242 deletions
1
changelog.d/8070.misc
Normal file
1
changelog.d/8070.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Convert various parts of the codebase to async/await.
|
|
@ -58,7 +58,6 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||||
"""
|
"""
|
||||||
for host in {get_domain_from_id(u) for u in members_changed}:
|
for host in {get_domain_from_id(u) for u in members_changed}:
|
||||||
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
|
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
|
||||||
self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
|
|
||||||
|
|
||||||
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
|
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
|
||||||
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
|
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
|
||||||
|
|
|
@ -256,81 +256,6 @@ class PushRulesWorkerStore(
|
||||||
):
|
):
|
||||||
yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
|
yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def bulk_get_push_rules_for_room(self, event, context):
|
|
||||||
state_group = context.state_group
|
|
||||||
if not state_group:
|
|
||||||
# If state_group is None it means it has yet to be assigned a
|
|
||||||
# state group, i.e. we need to make sure that calls with a state_group
|
|
||||||
# of None don't hit previous cached calls with a None state_group.
|
|
||||||
# To do this we set the state_group to a new object as object() != object()
|
|
||||||
state_group = object()
|
|
||||||
|
|
||||||
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
|
|
||||||
result = yield self._bulk_get_push_rules_for_room(
|
|
||||||
event.room_id, state_group, current_state_ids, event=event
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
|
||||||
def _bulk_get_push_rules_for_room(
|
|
||||||
self, room_id, state_group, current_state_ids, cache_context, event=None
|
|
||||||
):
|
|
||||||
# We don't use `state_group`, its there so that we can cache based
|
|
||||||
# on it. However, its important that its never None, since two current_state's
|
|
||||||
# with a state_group of None are likely to be different.
|
|
||||||
# See bulk_get_push_rules_for_room for how we work around this.
|
|
||||||
assert state_group is not None
|
|
||||||
|
|
||||||
# We also will want to generate notifs for other people in the room so
|
|
||||||
# their unread countss are correct in the event stream, but to avoid
|
|
||||||
# generating them for bot / AS users etc, we only do so for people who've
|
|
||||||
# sent a read receipt into the room.
|
|
||||||
|
|
||||||
users_in_room = yield self._get_joined_users_from_context(
|
|
||||||
room_id,
|
|
||||||
state_group,
|
|
||||||
current_state_ids,
|
|
||||||
on_invalidate=cache_context.invalidate,
|
|
||||||
event=event,
|
|
||||||
)
|
|
||||||
|
|
||||||
# We ignore app service users for now. This is so that we don't fill
|
|
||||||
# up the `get_if_users_have_pushers` cache with AS entries that we
|
|
||||||
# know don't have pushers, nor even read receipts.
|
|
||||||
local_users_in_room = {
|
|
||||||
u
|
|
||||||
for u in users_in_room
|
|
||||||
if self.hs.is_mine_id(u)
|
|
||||||
and not self.get_if_app_services_interested_in_user(u)
|
|
||||||
}
|
|
||||||
|
|
||||||
# users in the room who have pushers need to get push rules run because
|
|
||||||
# that's how their pushers work
|
|
||||||
if_users_with_pushers = yield self.get_if_users_have_pushers(
|
|
||||||
local_users_in_room, on_invalidate=cache_context.invalidate
|
|
||||||
)
|
|
||||||
user_ids = {
|
|
||||||
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
|
|
||||||
}
|
|
||||||
|
|
||||||
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
|
|
||||||
room_id, on_invalidate=cache_context.invalidate
|
|
||||||
)
|
|
||||||
|
|
||||||
# any users with pushers must be ours: they have pushers
|
|
||||||
for uid in users_with_receipts:
|
|
||||||
if uid in local_users_in_room:
|
|
||||||
user_ids.add(uid)
|
|
||||||
|
|
||||||
rules_by_user = yield self.bulk_get_push_rules(
|
|
||||||
user_ids, on_invalidate=cache_context.invalidate
|
|
||||||
)
|
|
||||||
|
|
||||||
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
|
|
||||||
|
|
||||||
return rules_by_user
|
|
||||||
|
|
||||||
@cachedList(
|
@cachedList(
|
||||||
cached_method_name="get_push_rules_enabled_for_user",
|
cached_method_name="get_push_rules_enabled_for_user",
|
||||||
list_name="user_ids",
|
list_name="user_ids",
|
||||||
|
|
|
@ -15,11 +15,13 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Iterable, List, Set
|
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
|
from synapse.events import EventBase
|
||||||
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.metrics import LaterGauge
|
from synapse.metrics import LaterGauge
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage._base import (
|
from synapse.storage._base import (
|
||||||
|
@ -40,9 +42,12 @@ from synapse.storage.roommember import (
|
||||||
from synapse.types import Collection, get_domain_from_id
|
from synapse.types import Collection, get_domain_from_id
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches import intern_string
|
from synapse.util.caches import intern_string
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
|
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.state import _StateCacheEntry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -150,12 +155,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached(max_entries=100000, iterable=True)
|
@cached(max_entries=100000, iterable=True)
|
||||||
def get_users_in_room(self, room_id):
|
def get_users_in_room(self, room_id: str):
|
||||||
return self.db_pool.runInteraction(
|
return self.db_pool.runInteraction(
|
||||||
"get_users_in_room", self.get_users_in_room_txn, room_id
|
"get_users_in_room", self.get_users_in_room_txn, room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_users_in_room_txn(self, txn, room_id):
|
def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
|
||||||
# If we can assume current_state_events.membership is up to date
|
# If we can assume current_state_events.membership is up to date
|
||||||
# then we can avoid a join, which is a Very Good Thing given how
|
# then we can avoid a join, which is a Very Good Thing given how
|
||||||
# frequently this function gets called.
|
# frequently this function gets called.
|
||||||
|
@ -178,11 +183,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
return [r[0] for r in txn]
|
return [r[0] for r in txn]
|
||||||
|
|
||||||
@cached(max_entries=100000)
|
@cached(max_entries=100000)
|
||||||
def get_room_summary(self, room_id):
|
def get_room_summary(self, room_id: str):
|
||||||
""" Get the details of a room roughly suitable for use by the room
|
""" Get the details of a room roughly suitable for use by the room
|
||||||
summary extension to /sync. Useful when lazy loading room members.
|
summary extension to /sync. Useful when lazy loading room members.
|
||||||
Args:
|
Args:
|
||||||
room_id (str): The room ID to query
|
room_id: The room ID to query
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[str, MemberSummary]:
|
Deferred[dict[str, MemberSummary]:
|
||||||
dict of membership states, pointing to a MemberSummary named tuple.
|
dict of membership states, pointing to a MemberSummary named tuple.
|
||||||
|
@ -261,78 +266,59 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
|
|
||||||
return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
|
return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
|
||||||
|
|
||||||
def _get_user_counts_in_room_txn(self, txn, room_id):
|
|
||||||
"""
|
|
||||||
Get the user count in a room by membership.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
room_id (str)
|
|
||||||
membership (Membership)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred[int]
|
|
||||||
"""
|
|
||||||
sql = """
|
|
||||||
SELECT m.membership, count(*) FROM room_memberships as m
|
|
||||||
INNER JOIN current_state_events as c USING(event_id)
|
|
||||||
WHERE c.type = 'm.room.member' AND c.room_id = ?
|
|
||||||
GROUP BY m.membership
|
|
||||||
"""
|
|
||||||
|
|
||||||
txn.execute(sql, (room_id,))
|
|
||||||
return {row[0]: row[1] for row in txn}
|
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
def get_invited_rooms_for_local_user(self, user_id):
|
def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
|
||||||
""" Get all the rooms the *local* user is invited to
|
"""Get all the rooms the *local* user is invited to.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The user ID.
|
user_id: The user ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A deferred list of RoomsForUser.
|
A awaitable list of RoomsForUser.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.get_rooms_for_local_user_where_membership_is(
|
return self.get_rooms_for_local_user_where_membership_is(
|
||||||
user_id, [Membership.INVITE]
|
user_id, [Membership.INVITE]
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_invite_for_local_user_in_room(
|
||||||
def get_invite_for_local_user_in_room(self, user_id, room_id):
|
self, user_id: str, room_id: str
|
||||||
"""Gets the invite for the given *local* user and room
|
) -> Optional[RoomsForUser]:
|
||||||
|
"""Gets the invite for the given *local* user and room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str)
|
user_id: The user ID to find the invite of.
|
||||||
room_id (str)
|
room_id: The room to user was invited to.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Resolves to either a RoomsForUser or None if no invite was
|
Either a RoomsForUser or None if no invite was found.
|
||||||
found.
|
|
||||||
"""
|
"""
|
||||||
invites = yield self.get_invited_rooms_for_local_user(user_id)
|
invites = await self.get_invited_rooms_for_local_user(user_id)
|
||||||
for invite in invites:
|
for invite in invites:
|
||||||
if invite.room_id == room_id:
|
if invite.room_id == room_id:
|
||||||
return invite
|
return invite
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_rooms_for_local_user_where_membership_is(
|
||||||
def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
|
self, user_id: str, membership_list: List[str]
|
||||||
|
) -> Optional[List[RoomsForUser]]:
|
||||||
"""Get all the rooms for this *local* user where the membership for this user
|
"""Get all the rooms for this *local* user where the membership for this user
|
||||||
matches one in the membership list.
|
matches one in the membership list.
|
||||||
|
|
||||||
Filters out forgotten rooms.
|
Filters out forgotten rooms.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The user ID.
|
user_id: The user ID.
|
||||||
membership_list (list): A list of synapse.api.constants.Membership
|
membership_list: A list of synapse.api.constants.Membership
|
||||||
values which the user must be in.
|
values which the user must be in.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[list[RoomsForUser]]
|
The RoomsForUser that the user matches the membership types.
|
||||||
"""
|
"""
|
||||||
if not membership_list:
|
if not membership_list:
|
||||||
return defer.succeed(None)
|
return None
|
||||||
|
|
||||||
rooms = yield self.db_pool.runInteraction(
|
rooms = await self.db_pool.runInteraction(
|
||||||
"get_rooms_for_local_user_where_membership_is",
|
"get_rooms_for_local_user_where_membership_is",
|
||||||
self._get_rooms_for_local_user_where_membership_is_txn,
|
self._get_rooms_for_local_user_where_membership_is_txn,
|
||||||
user_id,
|
user_id,
|
||||||
|
@ -340,12 +326,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now we filter out forgotten rooms
|
# Now we filter out forgotten rooms
|
||||||
forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
|
forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
|
||||||
return [room for room in rooms if room.room_id not in forgotten_rooms]
|
return [room for room in rooms if room.room_id not in forgotten_rooms]
|
||||||
|
|
||||||
def _get_rooms_for_local_user_where_membership_is_txn(
|
def _get_rooms_for_local_user_where_membership_is_txn(
|
||||||
self, txn, user_id, membership_list
|
self, txn, user_id: str, membership_list: List[str]
|
||||||
):
|
) -> List[RoomsForUser]:
|
||||||
# Paranoia check.
|
# Paranoia check.
|
||||||
if not self.hs.is_mine_id(user_id):
|
if not self.hs.is_mine_id(user_id):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -374,14 +360,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@cached(max_entries=500000, iterable=True)
|
@cached(max_entries=500000, iterable=True)
|
||||||
def get_rooms_for_user_with_stream_ordering(self, user_id):
|
def get_rooms_for_user_with_stream_ordering(self, user_id: str):
|
||||||
"""Returns a set of room_ids the user is currently joined to.
|
"""Returns a set of room_ids the user is currently joined to.
|
||||||
|
|
||||||
If a remote user only returns rooms this server is currently
|
If a remote user only returns rooms this server is currently
|
||||||
participating in.
|
participating in.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str)
|
user_id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
|
Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
|
||||||
|
@ -394,7 +380,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
|
def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
|
||||||
# We use `current_state_events` here and not `local_current_membership`
|
# We use `current_state_events` here and not `local_current_membership`
|
||||||
# as a) this gets called with remote users and b) this only gets called
|
# as a) this gets called with remote users and b) this only gets called
|
||||||
# for rooms the server is participating in.
|
# for rooms the server is participating in.
|
||||||
|
@ -458,37 +444,39 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
_get_users_server_still_shares_room_with_txn,
|
_get_users_server_still_shares_room_with_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_rooms_for_user(self, user_id: str, on_invalidate=None):
|
||||||
def get_rooms_for_user(self, user_id, on_invalidate=None):
|
|
||||||
"""Returns a set of room_ids the user is currently joined to.
|
"""Returns a set of room_ids the user is currently joined to.
|
||||||
|
|
||||||
If a remote user only returns rooms this server is currently
|
If a remote user only returns rooms this server is currently
|
||||||
participating in.
|
participating in.
|
||||||
"""
|
"""
|
||||||
rooms = yield self.get_rooms_for_user_with_stream_ordering(
|
rooms = await self.get_rooms_for_user_with_stream_ordering(
|
||||||
user_id, on_invalidate=on_invalidate
|
user_id, on_invalidate=on_invalidate
|
||||||
)
|
)
|
||||||
return frozenset(r.room_id for r in rooms)
|
return frozenset(r.room_id for r in rooms)
|
||||||
|
|
||||||
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
|
@cached(max_entries=500000, cache_context=True, iterable=True)
|
||||||
def get_users_who_share_room_with_user(self, user_id, cache_context):
|
async def get_users_who_share_room_with_user(
|
||||||
|
self, user_id: str, cache_context: _CacheContext
|
||||||
|
) -> Set[str]:
|
||||||
"""Returns the set of users who share a room with `user_id`
|
"""Returns the set of users who share a room with `user_id`
|
||||||
"""
|
"""
|
||||||
room_ids = yield self.get_rooms_for_user(
|
room_ids = await self.get_rooms_for_user(
|
||||||
user_id, on_invalidate=cache_context.invalidate
|
user_id, on_invalidate=cache_context.invalidate
|
||||||
)
|
)
|
||||||
|
|
||||||
user_who_share_room = set()
|
user_who_share_room = set()
|
||||||
for room_id in room_ids:
|
for room_id in room_ids:
|
||||||
user_ids = yield self.get_users_in_room(
|
user_ids = await self.get_users_in_room(
|
||||||
room_id, on_invalidate=cache_context.invalidate
|
room_id, on_invalidate=cache_context.invalidate
|
||||||
)
|
)
|
||||||
user_who_share_room.update(user_ids)
|
user_who_share_room.update(user_ids)
|
||||||
|
|
||||||
return user_who_share_room
|
return user_who_share_room
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_joined_users_from_context(
|
||||||
def get_joined_users_from_context(self, event, context):
|
self, event: EventBase, context: EventContext
|
||||||
|
):
|
||||||
state_group = context.state_group
|
state_group = context.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
# If state_group is None it means it has yet to be assigned a
|
# If state_group is None it means it has yet to be assigned a
|
||||||
|
@ -497,14 +485,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
# To do this we set the state_group to a new object as object() != object()
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
state_group = object()
|
state_group = object()
|
||||||
|
|
||||||
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
|
current_state_ids = await context.get_current_state_ids()
|
||||||
result = yield self._get_joined_users_from_context(
|
return await self._get_joined_users_from_context(
|
||||||
event.room_id, state_group, current_state_ids, event=event, context=context
|
event.room_id, state_group, current_state_ids, event=event, context=context
|
||||||
)
|
)
|
||||||
return result
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_joined_users_from_state(self, room_id, state_entry):
|
||||||
def get_joined_users_from_state(self, room_id, state_entry):
|
|
||||||
state_group = state_entry.state_group
|
state_group = state_entry.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
# If state_group is None it means it has yet to be assigned a
|
# If state_group is None it means it has yet to be assigned a
|
||||||
|
@ -514,16 +500,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
state_group = object()
|
state_group = object()
|
||||||
|
|
||||||
with Measure(self._clock, "get_joined_users_from_state"):
|
with Measure(self._clock, "get_joined_users_from_state"):
|
||||||
return (
|
return await self._get_joined_users_from_context(
|
||||||
yield self._get_joined_users_from_context(
|
|
||||||
room_id, state_group, state_entry.state, context=state_entry
|
room_id, state_group, state_entry.state, context=state_entry
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
@cachedInlineCallbacks(
|
@cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
|
||||||
num_args=2, cache_context=True, iterable=True, max_entries=100000
|
async def _get_joined_users_from_context(
|
||||||
)
|
|
||||||
def _get_joined_users_from_context(
|
|
||||||
self,
|
self,
|
||||||
room_id,
|
room_id,
|
||||||
state_group,
|
state_group,
|
||||||
|
@ -535,7 +517,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
# We don't use `state_group`, it's there so that we can cache based
|
# We don't use `state_group`, it's there so that we can cache based
|
||||||
# on it. However, it's important that it's never None, since two current_states
|
# on it. However, it's important that it's never None, since two current_states
|
||||||
# with a state_group of None are likely to be different.
|
# with a state_group of None are likely to be different.
|
||||||
# See bulk_get_push_rules_for_room for how we work around this.
|
|
||||||
assert state_group is not None
|
assert state_group is not None
|
||||||
|
|
||||||
users_in_room = {}
|
users_in_room = {}
|
||||||
|
@ -588,7 +569,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
missing_member_event_ids.append(event_id)
|
missing_member_event_ids.append(event_id)
|
||||||
|
|
||||||
if missing_member_event_ids:
|
if missing_member_event_ids:
|
||||||
event_to_memberships = yield self._get_joined_profiles_from_event_ids(
|
event_to_memberships = await self._get_joined_profiles_from_event_ids(
|
||||||
missing_member_event_ids
|
missing_member_event_ids
|
||||||
)
|
)
|
||||||
users_in_room.update((row for row in event_to_memberships.values() if row))
|
users_in_room.update((row for row in event_to_memberships.values() if row))
|
||||||
|
@ -612,12 +593,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
list_name="event_ids",
|
list_name="event_ids",
|
||||||
inlineCallbacks=True,
|
inlineCallbacks=True,
|
||||||
)
|
)
|
||||||
def _get_joined_profiles_from_event_ids(self, event_ids):
|
def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
|
||||||
"""For given set of member event_ids check if they point to a join
|
"""For given set of member event_ids check if they point to a join
|
||||||
event and if so return the associated user and profile info.
|
event and if so return the associated user and profile info.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_ids (Iterable[str]): The member event IDs to lookup
|
event_ids: The member event IDs to lookup
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
|
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
|
||||||
|
@ -644,8 +625,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
for row in rows
|
for row in rows
|
||||||
}
|
}
|
||||||
|
|
||||||
@cachedInlineCallbacks(max_entries=10000)
|
@cached(max_entries=10000)
|
||||||
def is_host_joined(self, room_id, host):
|
async def is_host_joined(self, room_id: str, host: str) -> bool:
|
||||||
if "%" in host or "_" in host:
|
if "%" in host or "_" in host:
|
||||||
raise Exception("Invalid host name")
|
raise Exception("Invalid host name")
|
||||||
|
|
||||||
|
@ -664,7 +645,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
# the returned user actually has the correct domain.
|
# the returned user actually has the correct domain.
|
||||||
like_clause = "%:" + host
|
like_clause = "%:" + host
|
||||||
|
|
||||||
rows = yield self.db_pool.execute(
|
rows = await self.db_pool.execute(
|
||||||
"is_host_joined", None, sql, room_id, like_clause
|
"is_host_joined", None, sql, room_id, like_clause
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -678,50 +659,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@cachedInlineCallbacks()
|
async def get_joined_hosts(self, room_id: str, state_entry):
|
||||||
def was_host_joined(self, room_id, host):
|
|
||||||
"""Check whether the server is or ever was in the room.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
room_id (str)
|
|
||||||
host (str)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred: Resolves to True if the host is/was in the room, otherwise
|
|
||||||
False.
|
|
||||||
"""
|
|
||||||
if "%" in host or "_" in host:
|
|
||||||
raise Exception("Invalid host name")
|
|
||||||
|
|
||||||
sql = """
|
|
||||||
SELECT user_id FROM room_memberships
|
|
||||||
WHERE room_id = ?
|
|
||||||
AND user_id LIKE ?
|
|
||||||
AND membership = 'join'
|
|
||||||
LIMIT 1
|
|
||||||
"""
|
|
||||||
|
|
||||||
# We do need to be careful to ensure that host doesn't have any wild cards
|
|
||||||
# in it, but we checked above for known ones and we'll check below that
|
|
||||||
# the returned user actually has the correct domain.
|
|
||||||
like_clause = "%:" + host
|
|
||||||
|
|
||||||
rows = yield self.db_pool.execute(
|
|
||||||
"was_host_joined", None, sql, room_id, like_clause
|
|
||||||
)
|
|
||||||
|
|
||||||
if not rows:
|
|
||||||
return False
|
|
||||||
|
|
||||||
user_id = rows[0][0]
|
|
||||||
if get_domain_from_id(user_id) != host:
|
|
||||||
# This can only happen if the host name has something funky in it
|
|
||||||
raise Exception("Invalid host name")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_joined_hosts(self, room_id, state_entry):
|
|
||||||
state_group = state_entry.state_group
|
state_group = state_entry.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
# If state_group is None it means it has yet to be assigned a
|
# If state_group is None it means it has yet to be assigned a
|
||||||
|
@ -731,32 +669,28 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
state_group = object()
|
state_group = object()
|
||||||
|
|
||||||
with Measure(self._clock, "get_joined_hosts"):
|
with Measure(self._clock, "get_joined_hosts"):
|
||||||
return (
|
return await self._get_joined_hosts(
|
||||||
yield self._get_joined_hosts(
|
|
||||||
room_id, state_group, state_entry.state, state_entry=state_entry
|
room_id, state_group, state_entry.state, state_entry=state_entry
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
|
@cached(num_args=2, max_entries=10000, iterable=True)
|
||||||
# @defer.inlineCallbacks
|
async def _get_joined_hosts(
|
||||||
def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
|
self, room_id, state_group, current_state_ids, state_entry
|
||||||
|
):
|
||||||
# We don't use `state_group`, its there so that we can cache based
|
# We don't use `state_group`, its there so that we can cache based
|
||||||
# on it. However, its important that its never None, since two current_state's
|
# on it. However, its important that its never None, since two current_state's
|
||||||
# with a state_group of None are likely to be different.
|
# with a state_group of None are likely to be different.
|
||||||
# See bulk_get_push_rules_for_room for how we work around this.
|
|
||||||
assert state_group is not None
|
assert state_group is not None
|
||||||
|
|
||||||
cache = yield self._get_joined_hosts_cache(room_id)
|
cache = await self._get_joined_hosts_cache(room_id)
|
||||||
joined_hosts = yield cache.get_destinations(state_entry)
|
return await cache.get_destinations(state_entry)
|
||||||
|
|
||||||
return joined_hosts
|
|
||||||
|
|
||||||
@cached(max_entries=10000)
|
@cached(max_entries=10000)
|
||||||
def _get_joined_hosts_cache(self, room_id):
|
def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
|
||||||
return _JoinedHostsCache(self, room_id)
|
return _JoinedHostsCache(self, room_id)
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2)
|
@cached(num_args=2)
|
||||||
def did_forget(self, user_id, room_id):
|
async def did_forget(self, user_id: str, room_id: str) -> bool:
|
||||||
"""Returns whether user_id has elected to discard history for room_id.
|
"""Returns whether user_id has elected to discard history for room_id.
|
||||||
|
|
||||||
Returns False if they have since re-joined."""
|
Returns False if they have since re-joined."""
|
||||||
|
@ -778,15 +712,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
rows = txn.fetchall()
|
rows = txn.fetchall()
|
||||||
return rows[0][0]
|
return rows[0][0]
|
||||||
|
|
||||||
count = yield self.db_pool.runInteraction("did_forget_membership", f)
|
count = await self.db_pool.runInteraction("did_forget_membership", f)
|
||||||
return count == 0
|
return count == 0
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
def get_forgotten_rooms_for_user(self, user_id):
|
def get_forgotten_rooms_for_user(self, user_id: str):
|
||||||
"""Gets all rooms the user has forgotten.
|
"""Gets all rooms the user has forgotten.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str)
|
user_id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[set[str]]
|
Deferred[set[str]]
|
||||||
|
@ -819,18 +753,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
|
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
|
||||||
def get_rooms_user_has_been_in(self, user_id):
|
|
||||||
"""Get all rooms that the user has ever been in.
|
"""Get all rooms that the user has ever been in.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str)
|
user_id: The user ID to get the rooms of.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[set[str]]: Set of room IDs.
|
Set of room IDs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
room_ids = yield self.db_pool.simple_select_onecol(
|
room_ids = await self.db_pool.simple_select_onecol(
|
||||||
table="room_memberships",
|
table="room_memberships",
|
||||||
keyvalues={"membership": Membership.JOIN, "user_id": user_id},
|
keyvalues={"membership": Membership.JOIN, "user_id": user_id},
|
||||||
retcol="room_id",
|
retcol="room_id",
|
||||||
|
@ -905,8 +838,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
||||||
where_clause="forgotten = 1",
|
where_clause="forgotten = 1",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _background_add_membership_profile(self, progress, batch_size):
|
||||||
def _background_add_membership_profile(self, progress, batch_size):
|
|
||||||
target_min_stream_id = progress.get(
|
target_min_stream_id = progress.get(
|
||||||
"target_min_stream_id_inclusive", self._min_stream_order_on_start
|
"target_min_stream_id_inclusive", self._min_stream_order_on_start
|
||||||
)
|
)
|
||||||
|
@ -971,19 +903,18 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
||||||
|
|
||||||
return len(rows)
|
return len(rows)
|
||||||
|
|
||||||
result = yield self.db_pool.runInteraction(
|
result = await self.db_pool.runInteraction(
|
||||||
_MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
|
_MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
yield self.db_pool.updates._end_background_update(
|
await self.db_pool.updates._end_background_update(
|
||||||
_MEMBERSHIP_PROFILE_UPDATE_NAME
|
_MEMBERSHIP_PROFILE_UPDATE_NAME
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _background_current_state_membership(self, progress, batch_size):
|
||||||
def _background_current_state_membership(self, progress, batch_size):
|
|
||||||
"""Update the new membership column on current_state_events.
|
"""Update the new membership column on current_state_events.
|
||||||
|
|
||||||
This works by iterating over all rooms in alphebetical order.
|
This works by iterating over all rooms in alphebetical order.
|
||||||
|
@ -1029,14 +960,14 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
||||||
# string, which will compare before all room IDs correctly.
|
# string, which will compare before all room IDs correctly.
|
||||||
last_processed_room = progress.get("last_processed_room", "")
|
last_processed_room = progress.get("last_processed_room", "")
|
||||||
|
|
||||||
row_count, finished = yield self.db_pool.runInteraction(
|
row_count, finished = await self.db_pool.runInteraction(
|
||||||
"_background_current_state_membership_update",
|
"_background_current_state_membership_update",
|
||||||
_background_current_state_membership_txn,
|
_background_current_state_membership_txn,
|
||||||
last_processed_room,
|
last_processed_room,
|
||||||
)
|
)
|
||||||
|
|
||||||
if finished:
|
if finished:
|
||||||
yield self.db_pool.updates._end_background_update(
|
await self.db_pool.updates._end_background_update(
|
||||||
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
|
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1047,7 +978,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||||
super(RoomMemberStore, self).__init__(database, db_conn, hs)
|
super(RoomMemberStore, self).__init__(database, db_conn, hs)
|
||||||
|
|
||||||
def forget(self, user_id, room_id):
|
def forget(self, user_id: str, room_id: str):
|
||||||
"""Indicate that user_id wishes to discard history for room_id."""
|
"""Indicate that user_id wishes to discard history for room_id."""
|
||||||
|
|
||||||
def f(txn):
|
def f(txn):
|
||||||
|
@ -1088,17 +1019,19 @@ class _JoinedHostsCache(object):
|
||||||
|
|
||||||
self._len = 0
|
self._len = 0
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]:
|
||||||
def get_destinations(self, state_entry):
|
|
||||||
"""Get set of destinations for a state entry
|
"""Get set of destinations for a state entry
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state_entry(synapse.state._StateCacheEntry)
|
state_entry
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The destinations as a set.
|
||||||
"""
|
"""
|
||||||
if state_entry.state_group == self.state_group:
|
if state_entry.state_group == self.state_group:
|
||||||
return frozenset(self.hosts_to_joined_users)
|
return frozenset(self.hosts_to_joined_users)
|
||||||
|
|
||||||
with (yield self.linearizer.queue(())):
|
with (await self.linearizer.queue(())):
|
||||||
if state_entry.state_group == self.state_group:
|
if state_entry.state_group == self.state_group:
|
||||||
pass
|
pass
|
||||||
elif state_entry.prev_group == self.state_group:
|
elif state_entry.prev_group == self.state_group:
|
||||||
|
@ -1110,7 +1043,7 @@ class _JoinedHostsCache(object):
|
||||||
user_id = state_key
|
user_id = state_key
|
||||||
known_joins = self.hosts_to_joined_users.setdefault(host, set())
|
known_joins = self.hosts_to_joined_users.setdefault(host, set())
|
||||||
|
|
||||||
event = yield self.store.get_event(event_id)
|
event = await self.store.get_event(event_id)
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
known_joins.add(user_id)
|
known_joins.add(user_id)
|
||||||
else:
|
else:
|
||||||
|
@ -1119,7 +1052,7 @@ class _JoinedHostsCache(object):
|
||||||
if not known_joins:
|
if not known_joins:
|
||||||
self.hosts_to_joined_users.pop(host, None)
|
self.hosts_to_joined_users.pop(host, None)
|
||||||
else:
|
else:
|
||||||
joined_users = yield self.store.get_joined_users_from_state(
|
joined_users = await self.store.get_joined_users_from_state(
|
||||||
self.room_id, state_entry
|
self.room_id, state_entry
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
|
from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
|
||||||
|
@ -10,6 +25,7 @@ from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
|
from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
|
||||||
|
from tests.test_utils import make_awaitable
|
||||||
|
|
||||||
|
|
||||||
class MessageAcceptTests(unittest.HomeserverTestCase):
|
class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||||
|
@ -173,7 +189,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||||
# Register a mock on the store so that the incoming update doesn't fail because
|
# Register a mock on the store so that the incoming update doesn't fail because
|
||||||
# we don't share a room with the user.
|
# we don't share a room with the user.
|
||||||
store = self.homeserver.get_datastore()
|
store = self.homeserver.get_datastore()
|
||||||
store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"]))
|
store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
|
||||||
|
|
||||||
# Manually inject a fake device list update. We need this update to include at
|
# Manually inject a fake device list update. We need this update to include at
|
||||||
# least one prev_id so that the user's device list will need to be retried.
|
# least one prev_id so that the user's device list will need to be retried.
|
||||||
|
|
Loading…
Reference in a new issue