mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-28 07:00:51 +03:00
Delete device messages asynchronously and in staged batches (#16240)
This commit is contained in:
parent
1e571cd664
commit
4f1840a88a
13 changed files with 154 additions and 37 deletions
1
changelog.d/16240.misc
Normal file
1
changelog.d/16240.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Delete device messages asynchronously and in staged batches using the task scheduler.
|
|
@ -43,9 +43,12 @@ from synapse.metrics.background_process_metrics import (
|
|||
)
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
JsonMapping,
|
||||
ScheduledTask,
|
||||
StrCollection,
|
||||
StreamKeyType,
|
||||
StreamToken,
|
||||
TaskStatus,
|
||||
UserID,
|
||||
get_domain_from_id,
|
||||
get_verify_key_from_cross_signing_key,
|
||||
|
@ -62,6 +65,7 @@ if TYPE_CHECKING:
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DELETE_DEVICE_MSGS_TASK_NAME = "delete_device_messages"
|
||||
MAX_DEVICE_DISPLAY_NAME_LEN = 100
|
||||
DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
|
||||
|
||||
|
@ -78,6 +82,7 @@ class DeviceWorkerHandler:
|
|||
self._appservice_handler = hs.get_application_service_handler()
|
||||
self._state_storage = hs.get_storage_controllers().state
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._event_sources = hs.get_event_sources()
|
||||
self.server_name = hs.hostname
|
||||
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
|
||||
self._query_appservices_for_keys = (
|
||||
|
@ -386,6 +391,7 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
self._account_data_handler = hs.get_account_data_handler()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.db_pool = hs.get_datastores().main.db_pool
|
||||
self._task_scheduler = hs.get_task_scheduler()
|
||||
|
||||
self.device_list_updater = DeviceListUpdater(hs, self)
|
||||
|
||||
|
@ -419,6 +425,10 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
self._delete_stale_devices,
|
||||
)
|
||||
|
||||
self._task_scheduler.register_action(
|
||||
self._delete_device_messages, DELETE_DEVICE_MSGS_TASK_NAME
|
||||
)
|
||||
|
||||
def _check_device_name_length(self, name: Optional[str]) -> None:
|
||||
"""
|
||||
Checks whether a device name is longer than the maximum allowed length.
|
||||
|
@ -530,6 +540,7 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
user_id: The user to delete devices from.
|
||||
device_ids: The list of device IDs to delete
|
||||
"""
|
||||
to_device_stream_id = self._event_sources.get_current_token().to_device_key
|
||||
|
||||
try:
|
||||
await self.store.delete_devices(user_id, device_ids)
|
||||
|
@ -559,12 +570,49 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
f"org.matrix.msc3890.local_notification_settings.{device_id}",
|
||||
)
|
||||
|
||||
# Delete device messages asynchronously and in batches using the task scheduler
|
||||
await self._task_scheduler.schedule_task(
|
||||
DELETE_DEVICE_MSGS_TASK_NAME,
|
||||
resource_id=device_id,
|
||||
params={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"up_to_stream_id": to_device_stream_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Pushers are deleted after `delete_access_tokens_for_user` is called so that
|
||||
# modules using `on_logged_out` hook can use them if needed.
|
||||
await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids)
|
||||
|
||||
await self.notify_device_update(user_id, device_ids)
|
||||
|
||||
DEVICE_MSGS_DELETE_BATCH_LIMIT = 100
|
||||
|
||||
async def _delete_device_messages(
|
||||
self,
|
||||
task: ScheduledTask,
|
||||
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
|
||||
"""Scheduler task to delete device messages in batch of `DEVICE_MSGS_DELETE_BATCH_LIMIT`."""
|
||||
assert task.params is not None
|
||||
user_id = task.params["user_id"]
|
||||
device_id = task.params["device_id"]
|
||||
up_to_stream_id = task.params["up_to_stream_id"]
|
||||
|
||||
res = await self.store.delete_messages_for_device(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
up_to_stream_id=up_to_stream_id,
|
||||
limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT,
|
||||
)
|
||||
|
||||
if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT:
|
||||
return TaskStatus.COMPLETE, None, None
|
||||
else:
|
||||
# There is probably still device messages to be deleted, let's keep the task active and it will be run
|
||||
# again in a subsequent scheduler loop run (probably the next one, if not too many tasks are running).
|
||||
return TaskStatus.ACTIVE, None, None
|
||||
|
||||
async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
|
||||
"""Update the given device
|
||||
|
||||
|
|
|
@ -183,6 +183,7 @@ class BasePresenceHandler(abc.ABC):
|
|||
writer"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
@ -473,8 +474,6 @@ class _NullContextManager(ContextManager[None]):
|
|||
class WorkerPresenceHandler(BasePresenceHandler):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self.hs = hs
|
||||
|
||||
self._presence_writer_instance = hs.config.worker.writers.presence[0]
|
||||
|
||||
# Route presence EDUs to the right worker
|
||||
|
@ -738,7 +737,6 @@ class WorkerPresenceHandler(BasePresenceHandler):
|
|||
class PresenceHandler(BasePresenceHandler):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self.hs = hs
|
||||
self.wheel_timer: WheelTimer[str] = WheelTimer()
|
||||
self.notifier = hs.get_notifier()
|
||||
|
||||
|
|
|
@ -40,6 +40,7 @@ from synapse.api.filtering import FilterCollection
|
|||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import EventBase
|
||||
from synapse.handlers.device import DELETE_DEVICE_MSGS_TASK_NAME
|
||||
from synapse.handlers.relations import BundledAggregations
|
||||
from synapse.logging import issue9533_logger
|
||||
from synapse.logging.context import current_context
|
||||
|
@ -268,6 +269,7 @@ class SyncHandler:
|
|||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._state_storage_controller = self._storage_controllers.state
|
||||
self._device_handler = hs.get_device_handler()
|
||||
self._task_scheduler = hs.get_task_scheduler()
|
||||
|
||||
self.should_calculate_push_rules = hs.config.push.enable_push
|
||||
|
||||
|
@ -360,11 +362,19 @@ class SyncHandler:
|
|||
# (since we now know that the device has received them)
|
||||
if since_token is not None:
|
||||
since_stream_id = since_token.to_device_key
|
||||
deleted = await self.store.delete_messages_for_device(
|
||||
sync_config.user.to_string(), sync_config.device_id, since_stream_id
|
||||
# Delete device messages asynchronously and in batches using the task scheduler
|
||||
await self._task_scheduler.schedule_task(
|
||||
DELETE_DEVICE_MSGS_TASK_NAME,
|
||||
resource_id=sync_config.device_id,
|
||||
params={
|
||||
"user_id": sync_config.user.to_string(),
|
||||
"device_id": sync_config.device_id,
|
||||
"up_to_stream_id": since_stream_id,
|
||||
},
|
||||
)
|
||||
logger.debug(
|
||||
"Deleted %d to-device messages up to %d", deleted, since_stream_id
|
||||
"Deletion of to-device messages up to %d scheduled",
|
||||
since_stream_id,
|
||||
)
|
||||
|
||||
if timeout == 0 or since_token is None or full_state:
|
||||
|
|
|
@ -445,13 +445,18 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
|
||||
@trace
|
||||
async def delete_messages_for_device(
|
||||
self, user_id: str, device_id: Optional[str], up_to_stream_id: int
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: Optional[str],
|
||||
up_to_stream_id: int,
|
||||
limit: int,
|
||||
) -> int:
|
||||
"""
|
||||
Args:
|
||||
user_id: The recipient user_id.
|
||||
device_id: The recipient device_id.
|
||||
up_to_stream_id: Where to delete messages up to.
|
||||
limit: maximum number of messages to delete
|
||||
|
||||
Returns:
|
||||
The number of messages deleted.
|
||||
|
@ -472,12 +477,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
log_kv({"message": "No changes in cache since last check"})
|
||||
return 0
|
||||
|
||||
ROW_ID_NAME = self.database_engine.row_id_name
|
||||
|
||||
def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
|
||||
sql = (
|
||||
"DELETE FROM device_inbox"
|
||||
" WHERE user_id = ? AND device_id = ?"
|
||||
" AND stream_id <= ?"
|
||||
)
|
||||
sql = f"""
|
||||
DELETE FROM device_inbox WHERE {ROW_ID_NAME} IN (
|
||||
SELECT {ROW_ID_NAME} FROM device_inbox
|
||||
WHERE user_id = ? AND device_id = ? AND stream_id <= ?
|
||||
LIMIT {limit}
|
||||
)
|
||||
"""
|
||||
txn.execute(sql, (user_id, device_id, up_to_stream_id))
|
||||
return txn.rowcount
|
||||
|
||||
|
@ -487,6 +496,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||
|
||||
log_kv({"message": f"deleted {count} messages for device", "count": count})
|
||||
|
||||
# In this case we don't know if we hit the limit or the delete is complete
|
||||
# so let's not update the cache.
|
||||
if count == limit:
|
||||
return count
|
||||
|
||||
# Update the cache, ensuring that we only ever increase the value
|
||||
updated_last_deleted_stream_id = self._last_device_delete_cache.get(
|
||||
(user_id, device_id), 0
|
||||
|
|
|
@ -1766,14 +1766,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
keyvalues={"user_id": user_id, "hidden": False},
|
||||
)
|
||||
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="device_inbox",
|
||||
column="device_id",
|
||||
values=device_ids,
|
||||
keyvalues={"user_id": user_id},
|
||||
)
|
||||
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="device_auth_providers",
|
||||
|
|
|
@ -939,11 +939,7 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
|
|||
receipts."""
|
||||
|
||||
def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None:
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
ROW_ID_NAME = "ctid"
|
||||
else:
|
||||
ROW_ID_NAME = "rowid"
|
||||
|
||||
ROW_ID_NAME = self.database_engine.row_id_name
|
||||
# Identify any duplicate receipts arising from
|
||||
# https://github.com/matrix-org/synapse/issues/14406.
|
||||
# The following query takes less than a minute on matrix.org.
|
||||
|
|
|
@ -100,6 +100,12 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM
|
|||
"""Gets a string giving the server version. For example: '3.22.0'"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def row_id_name(self) -> str:
|
||||
"""Gets the literal name representing a row id for this engine."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def in_transaction(self, conn: ConnectionType) -> bool:
|
||||
"""Whether the connection is currently in a transaction."""
|
||||
|
|
|
@ -211,6 +211,10 @@ class PostgresEngine(
|
|||
else:
|
||||
return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
|
||||
|
||||
@property
|
||||
def row_id_name(self) -> str:
|
||||
return "ctid"
|
||||
|
||||
def in_transaction(self, conn: psycopg2.extensions.connection) -> bool:
|
||||
return conn.status != psycopg2.extensions.STATUS_READY
|
||||
|
||||
|
|
|
@ -123,6 +123,10 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
|
|||
"""Gets a string giving the server version. For example: '3.22.0'."""
|
||||
return "%i.%i.%i" % sqlite3.sqlite_version_info
|
||||
|
||||
@property
|
||||
def row_id_name(self) -> str:
|
||||
return "rowid"
|
||||
|
||||
def in_transaction(self, conn: sqlite3.Connection) -> bool:
|
||||
return conn.in_transaction
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||
from synapse.storage.engines import BaseDatabaseEngine
|
||||
from synapse.storage.prepare_database import get_statements
|
||||
|
||||
FIX_INDEXES = """
|
||||
|
@ -37,7 +37,7 @@ CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
|
|||
|
||||
|
||||
def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None:
|
||||
rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid"
|
||||
rowid = database_engine.row_id_name
|
||||
|
||||
# remove duplicates from group_users & group_invites tables
|
||||
cur.execute(
|
||||
|
|
|
@ -77,6 +77,7 @@ class TaskScheduler:
|
|||
LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000 # 24hrs
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._hs = hs
|
||||
self._store = hs.get_datastores().main
|
||||
self._clock = hs.get_clock()
|
||||
self._running_tasks: Set[str] = set()
|
||||
|
@ -97,8 +98,6 @@ class TaskScheduler:
|
|||
"handle_scheduled_tasks",
|
||||
self._handle_scheduled_tasks,
|
||||
)
|
||||
else:
|
||||
self.replication_client = hs.get_replication_command_handler()
|
||||
|
||||
def register_action(
|
||||
self,
|
||||
|
@ -133,7 +132,7 @@ class TaskScheduler:
|
|||
params: Optional[JsonMapping] = None,
|
||||
) -> str:
|
||||
"""Schedule a new potentially resumable task. A function matching the specified
|
||||
`action` should have been previously registered with `register_action`.
|
||||
`action` should have be registered with `register_action` before the task is run.
|
||||
|
||||
Args:
|
||||
action: the name of a previously registered action
|
||||
|
@ -149,11 +148,6 @@ class TaskScheduler:
|
|||
Returns:
|
||||
The id of the scheduled task
|
||||
"""
|
||||
if action not in self._actions:
|
||||
raise Exception(
|
||||
f"No function associated with action {action} of the scheduled task"
|
||||
)
|
||||
|
||||
status = TaskStatus.SCHEDULED
|
||||
if timestamp is None or timestamp < self._clock.time_msec():
|
||||
timestamp = self._clock.time_msec()
|
||||
|
@ -175,7 +169,7 @@ class TaskScheduler:
|
|||
if self._run_background_tasks:
|
||||
await self._launch_task(task)
|
||||
else:
|
||||
self.replication_client.send_new_active_task(task.id)
|
||||
self._hs.get_replication_command_handler().send_new_active_task(task.id)
|
||||
|
||||
return task.id
|
||||
|
||||
|
@ -315,7 +309,10 @@ class TaskScheduler:
|
|||
"""
|
||||
assert self._run_background_tasks
|
||||
|
||||
assert task.action in self._actions
|
||||
if task.action not in self._actions:
|
||||
raise Exception(
|
||||
f"No function associated with action {task.action} of the scheduled task {task.id}"
|
||||
)
|
||||
function = self._actions[task.action]
|
||||
|
||||
async def wrapper() -> None:
|
||||
|
|
|
@ -30,6 +30,7 @@ from synapse.server import HomeServer
|
|||
from synapse.storage.databases.main.appservice import _make_exclusive_regex
|
||||
from synapse.types import JsonDict, create_requester
|
||||
from synapse.util import Clock
|
||||
from synapse.util.task_scheduler import TaskScheduler
|
||||
|
||||
from tests import unittest
|
||||
from tests.unittest import override_config
|
||||
|
@ -49,6 +50,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
|||
assert isinstance(handler, DeviceHandler)
|
||||
self.handler = handler
|
||||
self.store = hs.get_datastores().main
|
||||
self.device_message_handler = hs.get_device_message_handler()
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
|
@ -211,6 +213,51 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
self.assertIsNone(res)
|
||||
|
||||
def test_delete_device_and_big_device_inbox(self) -> None:
|
||||
"""Check that deleting a big device inbox is staged and batched asynchronously."""
|
||||
DEVICE_ID = "abc"
|
||||
sender = "@sender:" + self.hs.hostname
|
||||
receiver = "@receiver:" + self.hs.hostname
|
||||
self._record_user(sender, DEVICE_ID, DEVICE_ID)
|
||||
self._record_user(receiver, DEVICE_ID, DEVICE_ID)
|
||||
|
||||
# queue a bunch of messages in the inbox
|
||||
requester = create_requester(sender, device_id=DEVICE_ID)
|
||||
for i in range(0, DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT + 10):
|
||||
self.get_success(
|
||||
self.device_message_handler.send_device_message(
|
||||
requester, "message_type", {receiver: {"*": {"val": i}}}
|
||||
)
|
||||
)
|
||||
|
||||
# delete the device
|
||||
self.get_success(self.handler.delete_devices(receiver, [DEVICE_ID]))
|
||||
|
||||
# messages should be deleted up to DEVICE_MSGS_DELETE_BATCH_LIMIT straight away
|
||||
res = self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="device_inbox",
|
||||
keyvalues={"user_id": receiver},
|
||||
retcols=("user_id", "device_id", "stream_id"),
|
||||
desc="get_device_id_from_device_inbox",
|
||||
)
|
||||
)
|
||||
self.assertEqual(10, len(res))
|
||||
|
||||
# wait for the task scheduler to do a second delete pass
|
||||
self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)
|
||||
|
||||
# remaining messages should now be deleted
|
||||
res = self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="device_inbox",
|
||||
keyvalues={"user_id": receiver},
|
||||
retcols=("user_id", "device_id", "stream_id"),
|
||||
desc="get_device_id_from_device_inbox",
|
||||
)
|
||||
)
|
||||
self.assertEqual(0, len(res))
|
||||
|
||||
def test_update_device(self) -> None:
|
||||
self._record_users()
|
||||
|
||||
|
|
Loading…
Reference in a new issue