Convert account data, device inbox, and censor events databases to async/await (#8063)

Patrick Cloke 2020-08-12 09:29:06 -04:00 committed by GitHub
parent a3a59bab7b
commit d68e10f308
5 changed files with 99 additions and 87 deletions

changelog.d/8063.misc (new file)

@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.

Account data store (AccountDataWorkerStore, AccountDataStore)

@@ -16,15 +16,16 @@
 import abc
 import logging
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 from twisted.internet import defer

 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import JsonDict
 from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import _CacheContext, cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache

 logger = logging.getLogger(__name__)

@@ -97,13 +98,15 @@ class AccountDataWorkerStore(SQLBaseStore):
             "get_account_data_for_user", get_account_data_for_user_txn
         )

-    @cachedInlineCallbacks(num_args=2, max_entries=5000)
-    def get_global_account_data_by_type_for_user(self, data_type, user_id):
+    @cached(num_args=2, max_entries=5000)
+    async def get_global_account_data_by_type_for_user(
+        self, data_type: str, user_id: str
+    ) -> Optional[JsonDict]:
         """
         Returns:
-            Deferred: A dict
+            The account data.
         """
-        result = yield self.db_pool.simple_select_one_onecol(
+        result = await self.db_pool.simple_select_one_onecol(
             table="account_data",
             keyvalues={"user_id": user_id, "account_data_type": data_type},
             retcol="content",

@@ -280,9 +283,11 @@ class AccountDataWorkerStore(SQLBaseStore):
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
         )

-    @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
-    def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
-        ignored_account_data = yield self.get_global_account_data_by_type_for_user(
+    @cached(num_args=2, cache_context=True, max_entries=5000)
+    async def is_ignored_by(
+        self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
+    ) -> bool:
+        ignored_account_data = await self.get_global_account_data_by_type_for_user(
             "m.ignored_user_list",
             ignorer_user_id,
             on_invalidate=cache_context.invalidate,

@@ -307,24 +312,27 @@ class AccountDataStore(AccountDataWorkerStore):
         super(AccountDataStore, self).__init__(database, db_conn, hs)

-    def get_max_account_data_stream_id(self):
+    def get_max_account_data_stream_id(self) -> int:
         """Get the current max stream id for the private user data stream

         Returns:
-            A deferred int.
+            The maximum stream ID.
         """
         return self._account_data_id_gen.get_current_token()

-    @defer.inlineCallbacks
-    def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
+    async def add_account_data_to_room(
+        self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
+    ) -> int:
         """Add some account_data to a room for a user.

         Args:
-            user_id(str): The user to add a tag for.
-            room_id(str): The room to add a tag for.
-            account_data_type(str): The type of account_data to add.
-            content(dict): A json object to associate with the tag.
+            user_id: The user to add a tag for.
+            room_id: The room to add a tag for.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the tag.
+
         Returns:
-            A deferred that completes once the account_data has been added.
+            The maximum stream ID.
         """
         content_json = json_encoder.encode(content)

@@ -332,7 +340,7 @@ class AccountDataStore(AccountDataWorkerStore):
         # no need to lock here as room_account_data has a unique constraint
         # on (user_id, room_id, account_data_type) so simple_upsert will
         # retry if there is a conflict.
-        yield self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             desc="add_room_account_data",
             table="room_account_data",
             keyvalues={

@@ -350,7 +358,7 @@ class AccountDataStore(AccountDataWorkerStore):
            # doesn't sound any worse than the whole update getting lost,
            # which is what would happen if we combined the two into one
            # transaction.
-            yield self._update_max_stream_id(next_id)
+            await self._update_max_stream_id(next_id)

             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
             self.get_account_data_for_user.invalidate((user_id,))

@@ -359,18 +367,20 @@ class AccountDataStore(AccountDataWorkerStore):
                 (user_id, room_id, account_data_type), content
             )

-        result = self._account_data_id_gen.get_current_token()
-        return result
+        return self._account_data_id_gen.get_current_token()

-    @defer.inlineCallbacks
-    def add_account_data_for_user(self, user_id, account_data_type, content):
+    async def add_account_data_for_user(
+        self, user_id: str, account_data_type: str, content: JsonDict
+    ) -> int:
         """Add some account_data to a room for a user.

         Args:
-            user_id(str): The user to add a tag for.
-            account_data_type(str): The type of account_data to add.
-            content(dict): A json object to associate with the tag.
+            user_id: The user to add a tag for.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the tag.
+
         Returns:
-            A deferred that completes once the account_data has been added.
+            The maximum stream ID.
         """
         content_json = json_encoder.encode(content)

@@ -378,7 +388,7 @@ class AccountDataStore(AccountDataWorkerStore):
         # no need to lock here as account_data has a unique constraint on
        # (user_id, account_data_type) so simple_upsert will retry if
        # there is a conflict.
-        yield self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             desc="add_user_account_data",
             table="account_data",
             keyvalues={"user_id": user_id, "account_data_type": account_data_type},

@@ -396,7 +406,7 @@ class AccountDataStore(AccountDataWorkerStore):
            # Note: This is only here for backwards compat to allow admins to
            # roll back to a previous Synapse version. Next time we update the
            # database version we can remove this table.
-            yield self._update_max_stream_id(next_id)
+            await self._update_max_stream_id(next_id)

             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
             self.get_account_data_for_user.invalidate((user_id,))

@@ -404,14 +414,13 @@ class AccountDataStore(AccountDataWorkerStore):
                 (account_data_type, user_id)
             )

-        result = self._account_data_id_gen.get_current_token()
-        return result
+        return self._account_data_id_gen.get_current_token()

-    def _update_max_stream_id(self, next_id):
+    def _update_max_stream_id(self, next_id: int):
         """Update the max stream_id

         Args:
-            next_id(int): The the revision to advance to.
+            next_id: The the revision to advance to.
         """

         # Note: This is only here for backwards compat to allow admins to
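The conversion above is mechanical but worth seeing end to end: a method that was a @cachedInlineCallbacks or @defer.inlineCallbacks generator yielding Deferreds becomes an async def (still decorated with @cached where caching is wanted) that awaits the database call and returns the value directly. Below is a minimal, self-contained sketch of that shape using plain asyncio; FakeAccountDataPool and AccountDataStoreSketch are illustrative stand-ins, not Synapse's DatabasePool or the real store.

```python
import asyncio
from typing import Optional


class FakeAccountDataPool:
    """Illustrative stand-in for a database pool, backed by an in-memory dict."""

    def __init__(self, rows: dict):
        # rows maps (user_id, account_data_type) -> content
        self._rows = rows

    async def simple_select_one_onecol(self, table, keyvalues, retcol):
        # Pretend to run a SELECT; just look the row up in memory.
        await asyncio.sleep(0)
        return self._rows.get((keyvalues["user_id"], keyvalues["account_data_type"]))


class AccountDataStoreSketch:
    def __init__(self, db_pool: FakeAccountDataPool):
        self.db_pool = db_pool

    async def get_global_account_data_by_type_for_user(
        self, data_type: str, user_id: str
    ) -> Optional[dict]:
        # Formerly `result = yield self.db_pool.simple_select_one_onecol(...)`
        # under @defer.inlineCallbacks; the coroutine now awaits it directly.
        result = await self.db_pool.simple_select_one_onecol(
            table="account_data",
            keyvalues={"user_id": user_id, "account_data_type": data_type},
            retcol="content",
        )
        return result


async def main():
    store = AccountDataStoreSketch(
        FakeAccountDataPool(
            {("@alice:example.com", "m.ignored_user_list"): {"ignored_users": {}}}
        )
    )
    print(
        await store.get_global_account_data_by_type_for_user(
            "m.ignored_user_list", "@alice:example.com"
        )
    )


asyncio.run(main())
```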

Censor events store (CensorEventsStore)

@@ -16,8 +16,6 @@
 import logging
 from typing import TYPE_CHECKING

-from twisted.internet import defer
-
 from synapse.events.utils import prune_event_dict
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore

@@ -148,17 +146,16 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
             updatevalues={"json": pruned_json},
         )

-    @defer.inlineCallbacks
-    def expire_event(self, event_id):
+    async def expire_event(self, event_id: str) -> None:
         """Retrieve and expire an event that has expired, and delete its associated
         expiry timestamp. If the event can't be retrieved, delete its associated
         timestamp so we don't try to expire it again in the future.

         Args:
-            event_id (str): The ID of the event to delete.
+            event_id: The ID of the event to delete.
         """
         # Try to retrieve the event's content from the database or the event cache.
-        event = yield self.get_event(event_id)
+        event = await self.get_event(event_id)

         def delete_expired_event_txn(txn):
             # Delete the expiry timestamp associated with this event from the database.

@@ -193,7 +190,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
                 txn, "_get_event_cache", (event.event_id,)
             )

-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "delete_expired_event", delete_expired_event_txn
         )
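A detail of the expire_event change that recurs throughout the commit: only the outer method becomes a coroutine, while the transaction function handed to runInteraction stays an ordinary synchronous callable, since it runs against a database connection rather than the reactor. A rough, self-contained sketch of that shape, assuming a toy FakeDBPool.runInteraction that simply invokes the function inline (not Synapse's real implementation):

```python
import asyncio


class FakeDBPool:
    """Illustrative stand-in: runs the 'transaction' function inline."""

    async def runInteraction(self, desc, func, *args):
        # Synapse would run `func` on a database thread with a real cursor;
        # here None stands in for the cursor and the call happens inline.
        return func(None, *args)


class CensorEventsStoreSketch:
    def __init__(self):
        self.db_pool = FakeDBPool()
        self.expired = []

    async def expire_event(self, event_id: str) -> None:
        # The txn function is still synchronous; only the outer method awaits.
        def delete_expired_event_txn(txn):
            # In the real store this deletes the event's expiry timestamp
            # (and prunes the event) inside a single transaction.
            self.expired.append(event_id)

        await self.db_pool.runInteraction(
            "delete_expired_event", delete_expired_event_txn
        )


async def main():
    store = CensorEventsStoreSketch()
    await store.expire_event("$example_event_id")
    print(store.expired)


asyncio.run(main())
```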

Device inbox store (DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore, DeviceInboxStore)

@@ -16,8 +16,6 @@
 import logging
 from typing import List, Tuple

-from twisted.internet import defer
-
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool

@@ -31,24 +29,31 @@ class DeviceInboxWorkerStore(SQLBaseStore):
     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()

-    def get_new_messages_for_device(
-        self, user_id, device_id, last_stream_id, current_stream_id, limit=100
-    ):
+    async def get_new_messages_for_device(
+        self,
+        user_id: str,
+        device_id: str,
+        last_stream_id: int,
+        current_stream_id: int,
+        limit: int = 100,
+    ) -> Tuple[List[dict], int]:
         """
         Args:
-            user_id(str): The recipient user_id.
-            device_id(str): The recipient device_id.
-            current_stream_id(int): The current position of the to device
+            user_id: The recipient user_id.
+            device_id: The recipient device_id.
+            last_stream_id: The last stream ID checked.
+            current_stream_id: The current position of the to device
                 message stream.
+            limit: The maximum number of messages to retrieve.
+
         Returns:
-            Deferred ([dict], int): List of messages for the device and where
-                in the stream the messages got to.
+            A list of messages for the device and where in the stream the messages got to.
         """
         has_changed = self._device_inbox_stream_cache.has_entity_changed(
             user_id, last_stream_id
         )
         if not has_changed:
-            return defer.succeed(([], current_stream_id))
+            return ([], current_stream_id)

         def get_new_messages_for_device_txn(txn):
             sql = (

@@ -69,20 +74,22 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             stream_pos = current_stream_id
             return messages, stream_pos

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_new_messages_for_device", get_new_messages_for_device_txn
         )

     @trace
-    @defer.inlineCallbacks
-    def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
+    async def delete_messages_for_device(
+        self, user_id: str, device_id: str, up_to_stream_id: int
+    ) -> int:
         """
         Args:
-            user_id(str): The recipient user_id.
-            device_id(str): The recipient device_id.
-            up_to_stream_id(int): Where to delete messages up to.
+            user_id: The recipient user_id.
+            device_id: The recipient device_id.
+            up_to_stream_id: Where to delete messages up to.
+
         Returns:
-            A deferred that resolves to the number of messages deleted.
+            The number of messages deleted.
         """
         # If we have cached the last stream id we've deleted up to, we can
         # check if there is likely to be anything that needs deleting

@@ -109,7 +116,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id, device_id, up_to_stream_id))
             return txn.rowcount

-        count = yield self.db_pool.runInteraction(
+        count = await self.db_pool.runInteraction(
             "delete_messages_for_device", delete_messages_for_device_txn
         )

@@ -128,9 +135,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         return count

     @trace
-    def get_new_device_msgs_for_remote(
+    async def get_new_device_msgs_for_remote(
         self, destination, last_stream_id, current_stream_id, limit
-    ):
+    ) -> Tuple[List[dict], int]:
         """
         Args:
             destination(str): The name of the remote server.

@@ -139,8 +146,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             current_stream_id(int|long): The current position of the device
                 message stream.
         Returns:
-            Deferred ([dict], int|long): List of messages for the device and where
-                in the stream the messages got to.
+            A list of messages for the device and where in the stream the messages got to.
         """
         set_tag("destination", destination)

@@ -153,11 +159,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         )
         if not has_changed or last_stream_id == current_stream_id:
             log_kv({"message": "No new messages in stream"})
-            return defer.succeed(([], current_stream_id))
+            return ([], current_stream_id)

         if limit <= 0:
             # This can happen if we run out of room for EDUs in the transaction.
-            return defer.succeed(([], last_stream_id))
+            return ([], last_stream_id)

         @trace
         def get_new_messages_for_remote_destination_txn(txn):

@@ -178,7 +184,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             stream_pos = current_stream_id
             return messages, stream_pos

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_new_device_msgs_for_remote",
             get_new_messages_for_remote_destination_txn,
         )

@@ -290,16 +296,15 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
             self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
         )

-    @defer.inlineCallbacks
-    def _background_drop_index_device_inbox(self, progress, batch_size):
+    async def _background_drop_index_device_inbox(self, progress, batch_size):
         def reindex_txn(conn):
             txn = conn.cursor()
             txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
             txn.close()

-        yield self.db_pool.runWithConnection(reindex_txn)
+        await self.db_pool.runWithConnection(reindex_txn)

-        yield self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+        await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)

         return 1

@@ -320,21 +325,21 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
         )

     @trace
-    @defer.inlineCallbacks
-    def add_messages_to_device_inbox(
-        self, local_messages_by_user_then_device, remote_messages_by_destination
-    ):
+    async def add_messages_to_device_inbox(
+        self,
+        local_messages_by_user_then_device: dict,
+        remote_messages_by_destination: dict,
+    ) -> int:
         """Used to send messages from this server.

         Args:
-            sender_user_id(str): The ID of the user sending these messages.
-            local_messages_by_user_and_device(dict):
+            local_messages_by_user_and_device:
                 Dictionary of user_id to device_id to message.
-            remote_messages_by_destination(dict):
+            remote_messages_by_destination:
                 Dictionary of destination server_name to the EDU JSON to send.
+
         Returns:
-            A deferred stream_id that resolves when the messages have been
-            inserted.
+            The new stream_id.
         """

         def add_messages_txn(txn, now_ms, stream_id):

@@ -359,7 +364,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
         with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
             )
             for user_id in local_messages_by_user_then_device.keys():

@@ -371,10 +376,9 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
         return self._device_inbox_id_gen.get_current_token()

-    @defer.inlineCallbacks
-    def add_messages_from_remote_to_device_inbox(
-        self, origin, message_id, local_messages_by_user_then_device
-    ):
+    async def add_messages_from_remote_to_device_inbox(
+        self, origin: str, message_id: str, local_messages_by_user_then_device: dict
+    ) -> int:
         def add_messages_txn(txn, now_ms, stream_id):
             # Check if we've already inserted a matching message_id for that
             # origin. This can happen if the origin doesn't receive our

@@ -409,7 +413,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
         with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "add_messages_from_remote_to_device_inbox",
                 add_messages_txn,
                 now_ms,
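The early-return changes in get_new_messages_for_device and get_new_device_msgs_for_remote follow from the same conversion: inside an async def a ready result no longer needs defer.succeed, because returning the tuple directly still leaves the caller with a coroutine to await. A small illustrative sketch; reducing the stream-cache check to a boolean parameter is an assumption made purely for brevity:

```python
import asyncio
from typing import List, Tuple


async def get_new_messages_for_device_sketch(
    has_changed: bool, current_stream_id: int
) -> Tuple[List[dict], int]:
    if not has_changed:
        # Previously: return defer.succeed(([], current_stream_id))
        return ([], current_stream_id)
    # ...the real method would await db_pool.runInteraction(...) here.
    return ([{"type": "m.example"}], current_stream_id)


async def main():
    print(await get_new_messages_for_device_sketch(False, 42))
    print(await get_new_messages_for_device_sketch(True, 42))


asyncio.run(main())
```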

Typing notification tests (TypingNotificationsTestCase)

@@ -24,6 +24,7 @@ from synapse.api.errors import AuthError
 from synapse.types import UserID

 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 from tests.utils import register_federation_servlets

@@ -151,7 +152,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.datastore.get_current_state_deltas.return_value = (0, None)

         self.datastore.get_to_device_stream_token = lambda: 0
-        self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed(
+        self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
             ([], 0)
         )
         self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
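The test stub previously returned defer.succeed(([], 0)); now that the storage method is a coroutine, the lambda has to produce something awaitable, which is what make_awaitable supplies. A hypothetical equivalent of such a helper using plain asyncio (make_awaitable_sketch here is not the actual implementation in tests.test_utils):

```python
import asyncio


def make_awaitable_sketch(result):
    """Return a future that has already been resolved to `result`."""
    future = asyncio.Future()
    future.set_result(result)
    return future


async def main():
    # A lambda stub standing in for an async storage method, as in the test.
    get_new_device_msgs_for_remote = lambda *args, **kwargs: make_awaitable_sketch(
        ([], 0)
    )
    print(await get_new_device_msgs_for_remote("example.org", 0, 0, limit=10))


asyncio.run(main())
```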