Convert the roommember database to async/await. (#8070)

This commit is contained in:
Patrick Cloke 2020-08-12 12:14:34 -04:00 committed by GitHub
parent 5ecc8b5825
commit fbe930dad2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 116 additions and 242 deletions

1
changelog.d/8070.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@ -58,7 +58,6 @@ class SQLBaseStore(metaclass=ABCMeta):
""" """
for host in {get_domain_from_id(u) for u in members_changed}: for host in {get_domain_from_id(u) for u in members_changed}:
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) self._attempt_to_invalidate_cache("get_room_summary", (room_id,))

View file

@ -256,81 +256,6 @@ class PushRulesWorkerStore(
): ):
yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
@defer.inlineCallbacks
def bulk_get_push_rules_for_room(self, event, context):
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
result = yield self._bulk_get_push_rules_for_room(
event.room_id, state_group, current_state_ids, event=event
)
return result
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(
self, room_id, state_group, current_state_ids, cache_context, event=None
):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
# We also will want to generate notifs for other people in the room so
# their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
users_in_room = yield self._get_joined_users_from_context(
room_id,
state_group,
current_state_ids,
on_invalidate=cache_context.invalidate,
event=event,
)
# We ignore app service users for now. This is so that we don't fill
# up the `get_if_users_have_pushers` cache with AS entries that we
# know don't have pushers, nor even read receipts.
local_users_in_room = {
u
for u in users_in_room
if self.hs.is_mine_id(u)
and not self.get_if_app_services_interested_in_user(u)
}
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield self.get_if_users_have_pushers(
local_users_in_room, on_invalidate=cache_context.invalidate
)
user_ids = {
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
}
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
room_id, on_invalidate=cache_context.invalidate
)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in local_users_in_room:
user_ids.add(uid)
rules_by_user = yield self.bulk_get_push_rules(
user_ids, on_invalidate=cache_context.invalidate
)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
return rules_by_user
@cachedList( @cachedList(
cached_method_name="get_push_rules_enabled_for_user", cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids", list_name="user_ids",

View file

@ -15,11 +15,13 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Iterable, List, Set from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import ( from synapse.storage._base import (
@ -40,9 +42,12 @@ from synapse.storage.roommember import (
from synapse.types import Collection, get_domain_from_id from synapse.types import Collection, get_domain_from_id
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.state import _StateCacheEntry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -150,12 +155,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
) )
@cached(max_entries=100000, iterable=True) @cached(max_entries=100000, iterable=True)
def get_users_in_room(self, room_id): def get_users_in_room(self, room_id: str):
return self.db_pool.runInteraction( return self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id "get_users_in_room", self.get_users_in_room_txn, room_id
) )
def get_users_in_room_txn(self, txn, room_id): def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
# If we can assume current_state_events.membership is up to date # If we can assume current_state_events.membership is up to date
# then we can avoid a join, which is a Very Good Thing given how # then we can avoid a join, which is a Very Good Thing given how
# frequently this function gets called. # frequently this function gets called.
@ -178,11 +183,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return [r[0] for r in txn] return [r[0] for r in txn]
@cached(max_entries=100000) @cached(max_entries=100000)
def get_room_summary(self, room_id): def get_room_summary(self, room_id: str):
""" Get the details of a room roughly suitable for use by the room """ Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members. summary extension to /sync. Useful when lazy loading room members.
Args: Args:
room_id (str): The room ID to query room_id: The room ID to query
Returns: Returns:
Deferred[dict[str, MemberSummary]: Deferred[dict[str, MemberSummary]:
dict of membership states, pointing to a MemberSummary named tuple. dict of membership states, pointing to a MemberSummary named tuple.
@ -261,78 +266,59 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn) return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
def _get_user_counts_in_room_txn(self, txn, room_id):
"""
Get the user count in a room by membership.
Args:
room_id (str)
membership (Membership)
Returns:
Deferred[int]
"""
sql = """
SELECT m.membership, count(*) FROM room_memberships as m
INNER JOIN current_state_events as c USING(event_id)
WHERE c.type = 'm.room.member' AND c.room_id = ?
GROUP BY m.membership
"""
txn.execute(sql, (room_id,))
return {row[0]: row[1] for row in txn}
@cached() @cached()
def get_invited_rooms_for_local_user(self, user_id): def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
""" Get all the rooms the *local* user is invited to """Get all the rooms the *local* user is invited to.
Args: Args:
user_id (str): The user ID. user_id: The user ID.
Returns: Returns:
A deferred list of RoomsForUser. A awaitable list of RoomsForUser.
""" """
return self.get_rooms_for_local_user_where_membership_is( return self.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE] user_id, [Membership.INVITE]
) )
@defer.inlineCallbacks async def get_invite_for_local_user_in_room(
def get_invite_for_local_user_in_room(self, user_id, room_id): self, user_id: str, room_id: str
"""Gets the invite for the given *local* user and room ) -> Optional[RoomsForUser]:
"""Gets the invite for the given *local* user and room.
Args: Args:
user_id (str) user_id: The user ID to find the invite of.
room_id (str) room_id: The room to user was invited to.
Returns: Returns:
Deferred: Resolves to either a RoomsForUser or None if no invite was Either a RoomsForUser or None if no invite was found.
found.
""" """
invites = yield self.get_invited_rooms_for_local_user(user_id) invites = await self.get_invited_rooms_for_local_user(user_id)
for invite in invites: for invite in invites:
if invite.room_id == room_id: if invite.room_id == room_id:
return invite return invite
return None return None
@defer.inlineCallbacks async def get_rooms_for_local_user_where_membership_is(
def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list): self, user_id: str, membership_list: List[str]
""" Get all the rooms for this *local* user where the membership for this user ) -> Optional[List[RoomsForUser]]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list. matches one in the membership list.
Filters out forgotten rooms. Filters out forgotten rooms.
Args: Args:
user_id (str): The user ID. user_id: The user ID.
membership_list (list): A list of synapse.api.constants.Membership membership_list: A list of synapse.api.constants.Membership
values which the user must be in. values which the user must be in.
Returns: Returns:
Deferred[list[RoomsForUser]] The RoomsForUser that the user matches the membership types.
""" """
if not membership_list: if not membership_list:
return defer.succeed(None) return None
rooms = yield self.db_pool.runInteraction( rooms = await self.db_pool.runInteraction(
"get_rooms_for_local_user_where_membership_is", "get_rooms_for_local_user_where_membership_is",
self._get_rooms_for_local_user_where_membership_is_txn, self._get_rooms_for_local_user_where_membership_is_txn,
user_id, user_id,
@ -340,12 +326,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
) )
# Now we filter out forgotten rooms # Now we filter out forgotten rooms
forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id) forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
return [room for room in rooms if room.room_id not in forgotten_rooms] return [room for room in rooms if room.room_id not in forgotten_rooms]
def _get_rooms_for_local_user_where_membership_is_txn( def _get_rooms_for_local_user_where_membership_is_txn(
self, txn, user_id, membership_list self, txn, user_id: str, membership_list: List[str]
): ) -> List[RoomsForUser]:
# Paranoia check. # Paranoia check.
if not self.hs.is_mine_id(user_id): if not self.hs.is_mine_id(user_id):
raise Exception( raise Exception(
@ -374,14 +360,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results return results
@cached(max_entries=500000, iterable=True) @cached(max_entries=500000, iterable=True)
def get_rooms_for_user_with_stream_ordering(self, user_id): def get_rooms_for_user_with_stream_ordering(self, user_id: str):
"""Returns a set of room_ids the user is currently joined to. """Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently If a remote user only returns rooms this server is currently
participating in. participating in.
Args: Args:
user_id (str) user_id
Returns: Returns:
Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
@ -394,7 +380,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_id, user_id,
) )
def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id): def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
# We use `current_state_events` here and not `local_current_membership` # We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called # as a) this gets called with remote users and b) this only gets called
# for rooms the server is participating in. # for rooms the server is participating in.
@ -458,37 +444,39 @@ class RoomMemberWorkerStore(EventsWorkerStore):
_get_users_server_still_shares_room_with_txn, _get_users_server_still_shares_room_with_txn,
) )
@defer.inlineCallbacks async def get_rooms_for_user(self, user_id: str, on_invalidate=None):
def get_rooms_for_user(self, user_id, on_invalidate=None):
"""Returns a set of room_ids the user is currently joined to. """Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently If a remote user only returns rooms this server is currently
participating in. participating in.
""" """
rooms = yield self.get_rooms_for_user_with_stream_ordering( rooms = await self.get_rooms_for_user_with_stream_ordering(
user_id, on_invalidate=on_invalidate user_id, on_invalidate=on_invalidate
) )
return frozenset(r.room_id for r in rooms) return frozenset(r.room_id for r in rooms)
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) @cached(max_entries=500000, cache_context=True, iterable=True)
def get_users_who_share_room_with_user(self, user_id, cache_context): async def get_users_who_share_room_with_user(
self, user_id: str, cache_context: _CacheContext
) -> Set[str]:
"""Returns the set of users who share a room with `user_id` """Returns the set of users who share a room with `user_id`
""" """
room_ids = yield self.get_rooms_for_user( room_ids = await self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate user_id, on_invalidate=cache_context.invalidate
) )
user_who_share_room = set() user_who_share_room = set()
for room_id in room_ids: for room_id in room_ids:
user_ids = yield self.get_users_in_room( user_ids = await self.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate room_id, on_invalidate=cache_context.invalidate
) )
user_who_share_room.update(user_ids) user_who_share_room.update(user_ids)
return user_who_share_room return user_who_share_room
@defer.inlineCallbacks async def get_joined_users_from_context(
def get_joined_users_from_context(self, event, context): self, event: EventBase, context: EventContext
):
state_group = context.state_group state_group = context.state_group
if not state_group: if not state_group:
# If state_group is None it means it has yet to be assigned a # If state_group is None it means it has yet to be assigned a
@ -497,14 +485,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object() # To do this we set the state_group to a new object as object() != object()
state_group = object() state_group = object()
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) current_state_ids = await context.get_current_state_ids()
result = yield self._get_joined_users_from_context( return await self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context event.room_id, state_group, current_state_ids, event=event, context=context
) )
return result
@defer.inlineCallbacks async def get_joined_users_from_state(self, room_id, state_entry):
def get_joined_users_from_state(self, room_id, state_entry):
state_group = state_entry.state_group state_group = state_entry.state_group
if not state_group: if not state_group:
# If state_group is None it means it has yet to be assigned a # If state_group is None it means it has yet to be assigned a
@ -514,16 +500,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object() state_group = object()
with Measure(self._clock, "get_joined_users_from_state"): with Measure(self._clock, "get_joined_users_from_state"):
return ( return await self._get_joined_users_from_context(
yield self._get_joined_users_from_context( room_id, state_group, state_entry.state, context=state_entry
room_id, state_group, state_entry.state, context=state_entry
)
) )
@cachedInlineCallbacks( @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
num_args=2, cache_context=True, iterable=True, max_entries=100000 async def _get_joined_users_from_context(
)
def _get_joined_users_from_context(
self, self,
room_id, room_id,
state_group, state_group,
@ -535,7 +517,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# We don't use `state_group`, it's there so that we can cache based # We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states # on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different. # with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None assert state_group is not None
users_in_room = {} users_in_room = {}
@ -588,7 +569,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
missing_member_event_ids.append(event_id) missing_member_event_ids.append(event_id)
if missing_member_event_ids: if missing_member_event_ids:
event_to_memberships = yield self._get_joined_profiles_from_event_ids( event_to_memberships = await self._get_joined_profiles_from_event_ids(
missing_member_event_ids missing_member_event_ids
) )
users_in_room.update((row for row in event_to_memberships.values() if row)) users_in_room.update((row for row in event_to_memberships.values() if row))
@ -612,12 +593,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
list_name="event_ids", list_name="event_ids",
inlineCallbacks=True, inlineCallbacks=True,
) )
def _get_joined_profiles_from_event_ids(self, event_ids): def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join """For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info. event and if so return the associated user and profile info.
Args: Args:
event_ids (Iterable[str]): The member event IDs to lookup event_ids: The member event IDs to lookup
Returns: Returns:
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
@ -644,8 +625,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
for row in rows for row in rows
} }
@cachedInlineCallbacks(max_entries=10000) @cached(max_entries=10000)
def is_host_joined(self, room_id, host): async def is_host_joined(self, room_id: str, host: str) -> bool:
if "%" in host or "_" in host: if "%" in host or "_" in host:
raise Exception("Invalid host name") raise Exception("Invalid host name")
@ -664,7 +645,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# the returned user actually has the correct domain. # the returned user actually has the correct domain.
like_clause = "%:" + host like_clause = "%:" + host
rows = yield self.db_pool.execute( rows = await self.db_pool.execute(
"is_host_joined", None, sql, room_id, like_clause "is_host_joined", None, sql, room_id, like_clause
) )
@ -678,50 +659,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True return True
@cachedInlineCallbacks() async def get_joined_hosts(self, room_id: str, state_entry):
def was_host_joined(self, room_id, host):
"""Check whether the server is or ever was in the room.
Args:
room_id (str)
host (str)
Returns:
Deferred: Resolves to True if the host is/was in the room, otherwise
False.
"""
if "%" in host or "_" in host:
raise Exception("Invalid host name")
sql = """
SELECT user_id FROM room_memberships
WHERE room_id = ?
AND user_id LIKE ?
AND membership = 'join'
LIMIT 1
"""
# We do need to be careful to ensure that host doesn't have any wild cards
# in it, but we checked above for known ones and we'll check below that
# the returned user actually has the correct domain.
like_clause = "%:" + host
rows = yield self.db_pool.execute(
"was_host_joined", None, sql, room_id, like_clause
)
if not rows:
return False
user_id = rows[0][0]
if get_domain_from_id(user_id) != host:
# This can only happen if the host name has something funky in it
raise Exception("Invalid host name")
return True
@defer.inlineCallbacks
def get_joined_hosts(self, room_id, state_entry):
state_group = state_entry.state_group state_group = state_entry.state_group
if not state_group: if not state_group:
# If state_group is None it means it has yet to be assigned a # If state_group is None it means it has yet to be assigned a
@ -731,32 +669,28 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object() state_group = object()
with Measure(self._clock, "get_joined_hosts"): with Measure(self._clock, "get_joined_hosts"):
return ( return await self._get_joined_hosts(
yield self._get_joined_hosts( room_id, state_group, state_entry.state, state_entry=state_entry
room_id, state_group, state_entry.state, state_entry=state_entry
)
) )
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) @cached(num_args=2, max_entries=10000, iterable=True)
# @defer.inlineCallbacks async def _get_joined_hosts(
def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): self, room_id, state_group, current_state_ids, state_entry
):
# We don't use `state_group`, its there so that we can cache based # We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's # on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different. # with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None assert state_group is not None
cache = yield self._get_joined_hosts_cache(room_id) cache = await self._get_joined_hosts_cache(room_id)
joined_hosts = yield cache.get_destinations(state_entry) return await cache.get_destinations(state_entry)
return joined_hosts
@cached(max_entries=10000) @cached(max_entries=10000)
def _get_joined_hosts_cache(self, room_id): def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
return _JoinedHostsCache(self, room_id) return _JoinedHostsCache(self, room_id)
@cachedInlineCallbacks(num_args=2) @cached(num_args=2)
def did_forget(self, user_id, room_id): async def did_forget(self, user_id: str, room_id: str) -> bool:
"""Returns whether user_id has elected to discard history for room_id. """Returns whether user_id has elected to discard history for room_id.
Returns False if they have since re-joined.""" Returns False if they have since re-joined."""
@ -778,15 +712,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
rows = txn.fetchall() rows = txn.fetchall()
return rows[0][0] return rows[0][0]
count = yield self.db_pool.runInteraction("did_forget_membership", f) count = await self.db_pool.runInteraction("did_forget_membership", f)
return count == 0 return count == 0
@cached() @cached()
def get_forgotten_rooms_for_user(self, user_id): def get_forgotten_rooms_for_user(self, user_id: str):
"""Gets all rooms the user has forgotten. """Gets all rooms the user has forgotten.
Args: Args:
user_id (str) user_id
Returns: Returns:
Deferred[set[str]] Deferred[set[str]]
@ -819,18 +753,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
) )
@defer.inlineCallbacks async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
def get_rooms_user_has_been_in(self, user_id):
"""Get all rooms that the user has ever been in. """Get all rooms that the user has ever been in.
Args: Args:
user_id (str) user_id: The user ID to get the rooms of.
Returns: Returns:
Deferred[set[str]]: Set of room IDs. Set of room IDs.
""" """
room_ids = yield self.db_pool.simple_select_onecol( room_ids = await self.db_pool.simple_select_onecol(
table="room_memberships", table="room_memberships",
keyvalues={"membership": Membership.JOIN, "user_id": user_id}, keyvalues={"membership": Membership.JOIN, "user_id": user_id},
retcol="room_id", retcol="room_id",
@ -905,8 +838,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
where_clause="forgotten = 1", where_clause="forgotten = 1",
) )
@defer.inlineCallbacks async def _background_add_membership_profile(self, progress, batch_size):
def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get( target_min_stream_id = progress.get(
"target_min_stream_id_inclusive", self._min_stream_order_on_start "target_min_stream_id_inclusive", self._min_stream_order_on_start
) )
@ -971,19 +903,18 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
return len(rows) return len(rows)
result = yield self.db_pool.runInteraction( result = await self.db_pool.runInteraction(
_MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
) )
if not result: if not result:
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
_MEMBERSHIP_PROFILE_UPDATE_NAME _MEMBERSHIP_PROFILE_UPDATE_NAME
) )
return result return result
@defer.inlineCallbacks async def _background_current_state_membership(self, progress, batch_size):
def _background_current_state_membership(self, progress, batch_size):
"""Update the new membership column on current_state_events. """Update the new membership column on current_state_events.
This works by iterating over all rooms in alphebetical order. This works by iterating over all rooms in alphebetical order.
@ -1029,14 +960,14 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
# string, which will compare before all room IDs correctly. # string, which will compare before all room IDs correctly.
last_processed_room = progress.get("last_processed_room", "") last_processed_room = progress.get("last_processed_room", "")
row_count, finished = yield self.db_pool.runInteraction( row_count, finished = await self.db_pool.runInteraction(
"_background_current_state_membership_update", "_background_current_state_membership_update",
_background_current_state_membership_txn, _background_current_state_membership_txn,
last_processed_room, last_processed_room,
) )
if finished: if finished:
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
) )
@ -1047,7 +978,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberStore, self).__init__(database, db_conn, hs) super(RoomMemberStore, self).__init__(database, db_conn, hs)
def forget(self, user_id, room_id): def forget(self, user_id: str, room_id: str):
"""Indicate that user_id wishes to discard history for room_id.""" """Indicate that user_id wishes to discard history for room_id."""
def f(txn): def f(txn):
@ -1088,17 +1019,19 @@ class _JoinedHostsCache(object):
self._len = 0 self._len = 0
@defer.inlineCallbacks async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]:
def get_destinations(self, state_entry):
"""Get set of destinations for a state entry """Get set of destinations for a state entry
Args: Args:
state_entry(synapse.state._StateCacheEntry) state_entry
Returns:
The destinations as a set.
""" """
if state_entry.state_group == self.state_group: if state_entry.state_group == self.state_group:
return frozenset(self.hosts_to_joined_users) return frozenset(self.hosts_to_joined_users)
with (yield self.linearizer.queue(())): with (await self.linearizer.queue(())):
if state_entry.state_group == self.state_group: if state_entry.state_group == self.state_group:
pass pass
elif state_entry.prev_group == self.state_group: elif state_entry.prev_group == self.state_group:
@ -1110,7 +1043,7 @@ class _JoinedHostsCache(object):
user_id = state_key user_id = state_key
known_joins = self.hosts_to_joined_users.setdefault(host, set()) known_joins = self.hosts_to_joined_users.setdefault(host, set())
event = yield self.store.get_event(event_id) event = await self.store.get_event(event_id)
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
known_joins.add(user_id) known_joins.add(user_id)
else: else:
@ -1119,7 +1052,7 @@ class _JoinedHostsCache(object):
if not known_joins: if not known_joins:
self.hosts_to_joined_users.pop(host, None) self.hosts_to_joined_users.pop(host, None)
else: else:
joined_users = yield self.store.get_joined_users_from_state( joined_users = await self.store.get_joined_users_from_state(
self.room_id, state_entry self.room_id, state_entry
) )

View file

@ -1,3 +1,18 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock from mock import Mock
from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
@ -10,6 +25,7 @@ from synapse.util.retryutils import NotRetryingDestination
from tests import unittest from tests import unittest
from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
from tests.test_utils import make_awaitable
class MessageAcceptTests(unittest.HomeserverTestCase): class MessageAcceptTests(unittest.HomeserverTestCase):
@ -173,7 +189,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register a mock on the store so that the incoming update doesn't fail because # Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user. # we don't share a room with the user.
store = self.homeserver.get_datastore() store = self.homeserver.get_datastore()
store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"])) store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
# Manually inject a fake device list update. We need this update to include at # Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried. # least one prev_id so that the user's device list will need to be retried.