Convert misc database code to async (#8087)
parent 7bdf9828d5
commit 894dae74fe
11 changed files with 39 additions and 64 deletions
changelog.d/8087.misc (new file, +1)

@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
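The pattern repeated throughout this diff is the usual Twisted-to-native-coroutine conversion: drop @defer.inlineCallbacks, declare the function async def, and turn each yield on a Deferred into an await. A minimal before/after sketch with made-up names (fetch_value, get_value_twice), not code from this commit:

from twisted.internet import defer


def fetch_value(key):
    # Stand-in for a storage call that returns a Deferred/awaitable.
    return defer.succeed(len(key))


# Before: generator-based "coroutine" driven by Twisted.
@defer.inlineCallbacks
def get_value_twice_old(key):
    value = yield fetch_value(key)
    return value * 2


# After: a native coroutine, as in the hunks below.
async def get_value_twice(key):
    value = await fetch_value(key)
    return value * 2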
@@ -18,8 +18,6 @@ from typing import Optional
 
 from canonicaljson import json
 
-from twisted.internet import defer
-
 from synapse.metrics.background_process_metrics import run_as_background_process
 
 from . import engines

@@ -308,9 +306,8 @@ class BackgroundUpdater(object):
             update_name (str): Name of update
         """
 
-        @defer.inlineCallbacks
-        def noop_update(progress, batch_size):
-            yield self._end_background_update(update_name)
+        async def noop_update(progress, batch_size):
+            await self._end_background_update(update_name)
             return 1
 
         self.register_background_update_handler(update_name, noop_update)

@@ -409,12 +406,11 @@ class BackgroundUpdater(object):
         else:
             runner = create_index_sqlite
 
-        @defer.inlineCallbacks
-        def updater(progress, batch_size):
+        async def updater(progress, batch_size):
             if runner is not None:
                 logger.info("Adding index %s to %s", index_name, table)
-                yield self.db_pool.runWithConnection(runner)
-            yield self._end_background_update(update_name)
+                await self.db_pool.runWithConnection(runner)
+            await self._end_background_update(update_name)
             return 1
 
         self.register_background_update_handler(update_name, updater)
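Background update handlers such as noop_update and updater above are now native coroutines: they receive (progress, batch_size), await their database work, call _end_background_update when finished, and return a count of items processed for that batch. A sketch of registering such a handler (the subclass, update name and helper are hypothetical, mirroring the pattern above):

from synapse.storage.background_updates import BackgroundUpdater


class ExampleUpdater(BackgroundUpdater):  # hypothetical subclass, for illustration only
    def register_example_noop_update(self, update_name="example_update"):
        async def noop_update(progress, batch_size):
            await self._end_background_update(update_name)
            return 1  # items processed this batch

        self.register_background_update_handler(update_name, noop_update)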
@@ -671,10 +671,9 @@ class DeviceWorkerStore(SQLBaseStore):
     @cachedList(
         cached_method_name="get_device_list_last_stream_id_for_remote",
         list_name="user_ids",
-        inlineCallbacks=True,
     )
-    def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
-        rows = yield self.db_pool.simple_select_many_batch(
+    async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
+        rows = await self.db_pool.simple_select_many_batch(
             table="device_lists_remote_extremeties",
             column="user_id",
             iterable=user_ids,
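The @cachedList hunks here and below drop inlineCallbacks=True, and @cachedInlineCallbacks becomes @cached, because the decorated methods are now native coroutines. A rough sketch of the resulting shape, using an illustrative store and table (ExampleStore, widgets) rather than anything from this commit:

from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList


class ExampleStore(SQLBaseStore):  # hypothetical store, for illustration only
    @cached()
    async def get_widget_for_user(self, user_id):
        # Single-key lookup that backs the cache (details elided).
        ...

    @cachedList(cached_method_name="get_widget_for_user", list_name="user_ids")
    async def get_widgets_for_users(self, user_ids):
        # Batch lookup: no inlineCallbacks=True, and the database call is
        # awaited directly.  retcols/desc are assumed keyword arguments here.
        rows = await self.db_pool.simple_select_many_batch(
            table="widgets",  # hypothetical table
            column="user_id",
            iterable=user_ids,
            retcols=("user_id", "widget"),
            desc="get_widgets_for_users",
        )
        return {row["user_id"]: row["widget"] for row in rows}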
@@ -21,7 +21,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
 

@@ -86,18 +86,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         self._rotate_delay = 3
         self._rotate_count = 10000
 
-    @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
-    def get_unread_event_push_actions_by_room_for_user(
+    @cached(num_args=3, tree=True, max_entries=5000)
+    async def get_unread_event_push_actions_by_room_for_user(
         self, room_id, user_id, last_read_event_id
     ):
-        ret = yield self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_unread_event_push_actions_by_room",
             self._get_unread_counts_by_receipt_txn,
             room_id,
             user_id,
             last_read_event_id,
         )
-        return ret
 
     def _get_unread_counts_by_receipt_txn(
         self, txn, room_id, user_id, last_read_event_id
@@ -130,13 +130,10 @@ class PresenceStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_presence_for_user",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
+        cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
     )
-    def get_presence_for_users(self, user_ids):
-        rows = yield self.db_pool.simple_select_many_batch(
+    async def get_presence_for_users(self, user_ids):
+        rows = await self.db_pool.simple_select_many_batch(
             table="presence_stream",
             column="user_id",
             iterable=user_ids,
@@ -170,18 +170,15 @@ class PushRulesWorkerStore(
     )
 
     @cachedList(
-        cached_method_name="get_push_rules_for_user",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
+        cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
     )
-    def bulk_get_push_rules(self, user_ids):
+    async def bulk_get_push_rules(self, user_ids):
         if not user_ids:
             return {}
 
         results = {user_id: [] for user_id in user_ids}
 
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="push_rules",
             column="user_name",
             iterable=user_ids,

@@ -194,7 +191,7 @@ class PushRulesWorkerStore(
         for row in rows:
             results.setdefault(row["user_name"], []).append(row)
 
-        enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
+        enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
 
         for user_id, rules in results.items():
             use_new_defaults = user_id in self._users_new_default_push_rules

@@ -260,15 +257,14 @@ class PushRulesWorkerStore(
         cached_method_name="get_push_rules_enabled_for_user",
         list_name="user_ids",
         num_args=1,
-        inlineCallbacks=True,
     )
-    def bulk_get_push_rules_enabled(self, user_ids):
+    async def bulk_get_push_rules_enabled(self, user_ids):
         if not user_ids:
             return {}
 
         results = {user_id: {} for user_id in user_ids}
 
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="push_rules_enable",
             column="user_name",
             iterable=user_ids,
@@ -170,13 +170,10 @@ class PusherWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="get_if_user_has_pusher",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
+        cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
     )
-    def get_if_users_have_pushers(self, user_ids):
-        rows = yield self.db_pool.simple_select_many_batch(
+    async def get_if_users_have_pushers(self, user_ids):
+        rows = await self.db_pool.simple_select_many_batch(
             table="pushers",
             column="user_name",
             iterable=user_ids,
@@ -212,9 +212,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
         cached_method_name="_get_linearized_receipts_for_room",
         list_name="room_ids",
         num_args=3,
-        inlineCallbacks=True,
     )
-    def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+    async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
         if not room_ids:
             return {}
 

@@ -243,7 +242,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
             return self.db_pool.cursor_to_dict(txn)
 
-        txn_results = yield self.db_pool.runInteraction(
+        txn_results = await self.db_pool.runInteraction(
            "_get_linearized_receipts_for_rooms", f
         )
 
@@ -17,8 +17,6 @@
 import logging
 from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext

@@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             lambda: self._known_servers_count,
         )
 
-    @defer.inlineCallbacks
-    def _count_known_servers(self):
+    async def _count_known_servers(self):
         """
         Count the servers that this server knows about.
 

@@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             txn.execute(query)
             return list(txn)[0][0]
 
-        count = yield self.db_pool.runInteraction("get_known_servers", _transact)
+        count = await self.db_pool.runInteraction("get_known_servers", _transact)
 
         # We always know about ourselves, even if we have nothing in
         # room_memberships (for example, the server is new).
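The _count_known_servers change above shows the same conversion around runInteraction: the inner transaction function stays synchronous, and only the outer method becomes a coroutine that awaits the interaction. A stripped-down sketch of that shape (the store class, method and query here are illustrative, not the real ones):

from synapse.storage._base import SQLBaseStore


class ExampleWorkerStore(SQLBaseStore):  # hypothetical, for illustration only
    async def _count_rows(self):
        def _transact(txn):
            # Runs synchronously inside a database transaction.
            txn.execute("SELECT COUNT(*) FROM room_memberships")
            return list(txn)[0][0]

        # The interaction is awaited instead of being yielded from a generator.
        return await self.db_pool.runInteraction("count_rows", _transact)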
@@ -589,11 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_joined_profile_from_event_id",
-        list_name="event_ids",
-        inlineCallbacks=True,
+        cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
     )
-    def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
+    async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
         """For given set of member event_ids check if they point to a join
         event and if so return the associated user and profile info.
 

@@ -601,11 +596,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             event_ids: The member event IDs to lookup
 
         Returns:
-            Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
+            dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
                 to `user_id` and ProfileInfo (or None if not join event).
         """
 
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=event_ids,
@@ -273,12 +273,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         cached_method_name="_get_state_group_for_event",
         list_name="event_ids",
         num_args=1,
-        inlineCallbacks=True,
     )
-    def _get_state_group_for_events(self, event_ids):
+    async def _get_state_group_for_events(self, event_ids):
         """Returns mapping event_id -> state_group
         """
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
            table="event_to_state_groups",
            column="event_id",
            iterable=event_ids,
@@ -38,10 +38,8 @@ class UserErasureWorkerStore(SQLBaseStore):
             desc="is_user_erased",
         ).addCallback(operator.truth)
 
-    @cachedList(
-        cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
-    )
-    def are_users_erased(self, user_ids):
+    @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
+    async def are_users_erased(self, user_ids):
         """
         Checks which users in a list have requested erasure
 

@@ -49,14 +47,14 @@ class UserErasureWorkerStore(SQLBaseStore):
             user_ids (iterable[str]): full user id to check
 
         Returns:
-            Deferred[dict[str, bool]]:
+            dict[str, bool]:
                 for each user, whether the user has requested erasure.
         """
         # this serves the dual purpose of (a) making sure we can do len and
         # iterate it multiple times, and (b) avoiding duplicates.
         user_ids = tuple(set(user_ids))
 
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="erased_users",
             column="user_id",
             iterable=user_ids,

@@ -65,8 +63,7 @@ class UserErasureWorkerStore(SQLBaseStore):
         )
         erased_users = {row["user_id"] for row in rows}
 
-        res = {u: u in erased_users for u in user_ids}
-        return res
+        return {u: u in erased_users for u in user_ids}
 
 
 class UserErasureStore(UserErasureWorkerStore):
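From the caller's side, converted methods such as are_users_erased are now plain coroutines whose awaited result is an ordinary value (here a dict[str, bool]) rather than a Deferred. A hypothetical call site:

async def drop_erased_senders(store, user_ids):
    # store is a UserErasureWorkerStore; awaiting are_users_erased yields a
    # dict mapping each user_id to whether that user requested erasure.
    erased = await store.are_users_erased(user_ids)
    return [user_id for user_id in user_ids if not erased.get(user_id)]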