mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-29 07:28:55 +03:00
Move additional tasks to the background worker, part 4 (#8513)
This commit is contained in:
parent
b2486f6656
commit
629a951b49
11 changed files with 199 additions and 224 deletions
1
changelog.d/8513.feature
Normal file
1
changelog.d/8513.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Allow running background tasks in a separate worker process.
|
|
@ -70,7 +70,8 @@ class AccountValidityHandler:
|
||||||
"send_renewals", self._send_renewal_emails
|
"send_renewals", self._send_renewal_emails
|
||||||
)
|
)
|
||||||
|
|
||||||
self.clock.looping_call(send_emails, 30 * 60 * 1000)
|
if hs.config.run_background_tasks:
|
||||||
|
self.clock.looping_call(send_emails, 30 * 60 * 1000)
|
||||||
|
|
||||||
async def _send_renewal_emails(self):
|
async def _send_renewal_emails(self):
|
||||||
"""Gets the list of users whose account is expiring in the amount of time
|
"""Gets the list of users whose account is expiring in the amount of time
|
||||||
|
|
|
@ -45,7 +45,7 @@ class DeactivateAccountHandler(BaseHandler):
|
||||||
|
|
||||||
# Start the user parter loop so it can resume parting users from rooms where
|
# Start the user parter loop so it can resume parting users from rooms where
|
||||||
# it left off (if it has work left to do).
|
# it left off (if it has work left to do).
|
||||||
if hs.config.worker_app is None:
|
if hs.config.run_background_tasks:
|
||||||
hs.get_reactor().callWhenRunning(self._start_user_parting)
|
hs.get_reactor().callWhenRunning(self._start_user_parting)
|
||||||
|
|
||||||
self._account_validity_enabled = hs.config.account_validity.enabled
|
self._account_validity_enabled = hs.config.account_validity.enabled
|
||||||
|
|
|
@ -402,21 +402,23 @@ class EventCreationHandler:
|
||||||
self.config.block_events_without_consent_error
|
self.config.block_events_without_consent_error
|
||||||
)
|
)
|
||||||
|
|
||||||
# Rooms which should be excluded from dummy insertion. (For instance,
|
|
||||||
# those without local users who can send events into the room).
|
|
||||||
#
|
|
||||||
# map from room id to time-of-last-attempt.
|
|
||||||
#
|
|
||||||
self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int]
|
|
||||||
|
|
||||||
# we need to construct a ConsentURIBuilder here, as it checks that the necessary
|
# we need to construct a ConsentURIBuilder here, as it checks that the necessary
|
||||||
# config options, but *only* if we have a configuration for which we are
|
# config options, but *only* if we have a configuration for which we are
|
||||||
# going to need it.
|
# going to need it.
|
||||||
if self._block_events_without_consent_error:
|
if self._block_events_without_consent_error:
|
||||||
self._consent_uri_builder = ConsentURIBuilder(self.config)
|
self._consent_uri_builder = ConsentURIBuilder(self.config)
|
||||||
|
|
||||||
|
# Rooms which should be excluded from dummy insertion. (For instance,
|
||||||
|
# those without local users who can send events into the room).
|
||||||
|
#
|
||||||
|
# map from room id to time-of-last-attempt.
|
||||||
|
#
|
||||||
|
self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int]
|
||||||
|
# The number of forward extremeities before a dummy event is sent.
|
||||||
|
self._dummy_events_threshold = hs.config.dummy_events_threshold
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not self.config.worker_app
|
self.config.run_background_tasks
|
||||||
and self.config.cleanup_extremities_with_dummy_events
|
and self.config.cleanup_extremities_with_dummy_events
|
||||||
):
|
):
|
||||||
self.clock.looping_call(
|
self.clock.looping_call(
|
||||||
|
@ -431,8 +433,6 @@ class EventCreationHandler:
|
||||||
|
|
||||||
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
|
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
|
||||||
|
|
||||||
self._dummy_events_threshold = hs.config.dummy_events_threshold
|
|
||||||
|
|
||||||
async def create_event(
|
async def create_event(
|
||||||
self,
|
self,
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
|
|
|
@ -92,7 +92,7 @@ class PaginationHandler:
|
||||||
self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
|
self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
|
||||||
self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max
|
self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max
|
||||||
|
|
||||||
if hs.config.retention_enabled:
|
if hs.config.run_background_tasks and hs.config.retention_enabled:
|
||||||
# Run the purge jobs described in the configuration file.
|
# Run the purge jobs described in the configuration file.
|
||||||
for job in hs.config.retention_purge_jobs:
|
for job in hs.config.retention_purge_jobs:
|
||||||
logger.info("Setting up purge job with config: %s", job)
|
logger.info("Setting up purge job with config: %s", job)
|
||||||
|
|
|
@ -35,14 +35,16 @@ MAX_DISPLAYNAME_LEN = 256
|
||||||
MAX_AVATAR_URL_LEN = 1000
|
MAX_AVATAR_URL_LEN = 1000
|
||||||
|
|
||||||
|
|
||||||
class BaseProfileHandler(BaseHandler):
|
class ProfileHandler(BaseHandler):
|
||||||
"""Handles fetching and updating user profile information.
|
"""Handles fetching and updating user profile information.
|
||||||
|
|
||||||
BaseProfileHandler can be instantiated directly on workers and will
|
ProfileHandler can be instantiated directly on workers and will
|
||||||
delegate to master when necessary. The master process should use the
|
delegate to master when necessary.
|
||||||
subclass MasterProfileHandler
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
PROFILE_UPDATE_MS = 60 * 1000
|
||||||
|
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
|
@ -53,6 +55,11 @@ class BaseProfileHandler(BaseHandler):
|
||||||
|
|
||||||
self.user_directory_handler = hs.get_user_directory_handler()
|
self.user_directory_handler = hs.get_user_directory_handler()
|
||||||
|
|
||||||
|
if hs.config.run_background_tasks:
|
||||||
|
self.clock.looping_call(
|
||||||
|
self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
|
||||||
|
)
|
||||||
|
|
||||||
async def get_profile(self, user_id):
|
async def get_profile(self, user_id):
|
||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
|
|
||||||
|
@ -363,20 +370,6 @@ class BaseProfileHandler(BaseHandler):
|
||||||
raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
|
raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
class MasterProfileHandler(BaseProfileHandler):
|
|
||||||
PROFILE_UPDATE_MS = 60 * 1000
|
|
||||||
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
|
||||||
super().__init__(hs)
|
|
||||||
|
|
||||||
assert hs.config.worker_app is None
|
|
||||||
|
|
||||||
self.clock.looping_call(
|
|
||||||
self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
|
|
||||||
)
|
|
||||||
|
|
||||||
def _start_update_remote_profile_cache(self):
|
def _start_update_remote_profile_cache(self):
|
||||||
return run_as_background_process(
|
return run_as_background_process(
|
||||||
"Update remote profile", self._update_remote_profile_cache
|
"Update remote profile", self._update_remote_profile_cache
|
||||||
|
|
|
@ -75,7 +75,7 @@ from synapse.handlers.message import EventCreationHandler, MessageHandler
|
||||||
from synapse.handlers.pagination import PaginationHandler
|
from synapse.handlers.pagination import PaginationHandler
|
||||||
from synapse.handlers.password_policy import PasswordPolicyHandler
|
from synapse.handlers.password_policy import PasswordPolicyHandler
|
||||||
from synapse.handlers.presence import PresenceHandler
|
from synapse.handlers.presence import PresenceHandler
|
||||||
from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler
|
from synapse.handlers.profile import ProfileHandler
|
||||||
from synapse.handlers.read_marker import ReadMarkerHandler
|
from synapse.handlers.read_marker import ReadMarkerHandler
|
||||||
from synapse.handlers.receipts import ReceiptsHandler
|
from synapse.handlers.receipts import ReceiptsHandler
|
||||||
from synapse.handlers.register import RegistrationHandler
|
from synapse.handlers.register import RegistrationHandler
|
||||||
|
@ -191,7 +191,12 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
REQUIRED_ON_BACKGROUND_TASK_STARTUP = [
|
REQUIRED_ON_BACKGROUND_TASK_STARTUP = [
|
||||||
|
"account_validity",
|
||||||
"auth",
|
"auth",
|
||||||
|
"deactivate_account",
|
||||||
|
"message",
|
||||||
|
"pagination",
|
||||||
|
"profile",
|
||||||
"stats",
|
"stats",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -462,10 +467,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_profile_handler(self):
|
def get_profile_handler(self):
|
||||||
if self.config.worker_app:
|
return ProfileHandler(self)
|
||||||
return BaseProfileHandler(self)
|
|
||||||
else:
|
|
||||||
return MasterProfileHandler(self)
|
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_event_creation_handler(self) -> EventCreationHandler:
|
def get_event_creation_handler(self) -> EventCreationHandler:
|
||||||
|
|
|
@ -91,27 +91,6 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||||
desc="set_profile_avatar_url",
|
desc="set_profile_avatar_url",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProfileStore(ProfileWorkerStore):
|
|
||||||
async def add_remote_profile_cache(
|
|
||||||
self, user_id: str, displayname: str, avatar_url: str
|
|
||||||
) -> None:
|
|
||||||
"""Ensure we are caching the remote user's profiles.
|
|
||||||
|
|
||||||
This should only be called when `is_subscribed_remote_profile_for_user`
|
|
||||||
would return true for the user.
|
|
||||||
"""
|
|
||||||
await self.db_pool.simple_upsert(
|
|
||||||
table="remote_profile_cache",
|
|
||||||
keyvalues={"user_id": user_id},
|
|
||||||
values={
|
|
||||||
"displayname": displayname,
|
|
||||||
"avatar_url": avatar_url,
|
|
||||||
"last_check": self._clock.time_msec(),
|
|
||||||
},
|
|
||||||
desc="add_remote_profile_cache",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def update_remote_profile_cache(
|
async def update_remote_profile_cache(
|
||||||
self, user_id: str, displayname: str, avatar_url: str
|
self, user_id: str, displayname: str, avatar_url: str
|
||||||
) -> int:
|
) -> int:
|
||||||
|
@ -138,28 +117,6 @@ class ProfileStore(ProfileWorkerStore):
|
||||||
desc="delete_remote_profile_cache",
|
desc="delete_remote_profile_cache",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_remote_profile_cache_entries_that_expire(
|
|
||||||
self, last_checked: int
|
|
||||||
) -> Dict[str, str]:
|
|
||||||
"""Get all users who haven't been checked since `last_checked`
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _get_remote_profile_cache_entries_that_expire_txn(txn):
|
|
||||||
sql = """
|
|
||||||
SELECT user_id, displayname, avatar_url
|
|
||||||
FROM remote_profile_cache
|
|
||||||
WHERE last_check < ?
|
|
||||||
"""
|
|
||||||
|
|
||||||
txn.execute(sql, (last_checked,))
|
|
||||||
|
|
||||||
return self.db_pool.cursor_to_dict(txn)
|
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
|
||||||
"get_remote_profile_cache_entries_that_expire",
|
|
||||||
_get_remote_profile_cache_entries_that_expire_txn,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def is_subscribed_remote_profile_for_user(self, user_id):
|
async def is_subscribed_remote_profile_for_user(self, user_id):
|
||||||
"""Check whether we are interested in a remote user's profile.
|
"""Check whether we are interested in a remote user's profile.
|
||||||
"""
|
"""
|
||||||
|
@ -184,3 +141,46 @@ class ProfileStore(ProfileWorkerStore):
|
||||||
|
|
||||||
if res:
|
if res:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def get_remote_profile_cache_entries_that_expire(
|
||||||
|
self, last_checked: int
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
"""Get all users who haven't been checked since `last_checked`
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_remote_profile_cache_entries_that_expire_txn(txn):
|
||||||
|
sql = """
|
||||||
|
SELECT user_id, displayname, avatar_url
|
||||||
|
FROM remote_profile_cache
|
||||||
|
WHERE last_check < ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.execute(sql, (last_checked,))
|
||||||
|
|
||||||
|
return self.db_pool.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"get_remote_profile_cache_entries_that_expire",
|
||||||
|
_get_remote_profile_cache_entries_that_expire_txn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ProfileStore(ProfileWorkerStore):
|
||||||
|
async def add_remote_profile_cache(
|
||||||
|
self, user_id: str, displayname: str, avatar_url: str
|
||||||
|
) -> None:
|
||||||
|
"""Ensure we are caching the remote user's profiles.
|
||||||
|
|
||||||
|
This should only be called when `is_subscribed_remote_profile_for_user`
|
||||||
|
would return true for the user.
|
||||||
|
"""
|
||||||
|
await self.db_pool.simple_upsert(
|
||||||
|
table="remote_profile_cache",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
values={
|
||||||
|
"displayname": displayname,
|
||||||
|
"avatar_url": avatar_url,
|
||||||
|
"last_check": self._clock.time_msec(),
|
||||||
|
},
|
||||||
|
desc="add_remote_profile_cache",
|
||||||
|
)
|
||||||
|
|
|
@ -862,6 +862,32 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||||
values={"expiration_ts_ms": expiration_ts, "email_sent": False},
|
values={"expiration_ts_ms": expiration_ts, "email_sent": False},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_user_pending_deactivation(self) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Gets one user from the table of users waiting to be parted from all the rooms
|
||||||
|
they're in.
|
||||||
|
"""
|
||||||
|
return await self.db_pool.simple_select_one_onecol(
|
||||||
|
"users_pending_deactivation",
|
||||||
|
keyvalues={},
|
||||||
|
retcol="user_id",
|
||||||
|
allow_none=True,
|
||||||
|
desc="get_users_pending_deactivation",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def del_user_pending_deactivation(self, user_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Removes the given user to the table of users who need to be parted from all the
|
||||||
|
rooms they're in, effectively marking that user as fully deactivated.
|
||||||
|
"""
|
||||||
|
# XXX: This should be simple_delete_one but we failed to put a unique index on
|
||||||
|
# the table, so somehow duplicate entries have ended up in it.
|
||||||
|
await self.db_pool.simple_delete(
|
||||||
|
"users_pending_deactivation",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
desc="del_user_pending_deactivation",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||||
|
@ -1371,32 +1397,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||||
desc="add_user_pending_deactivation",
|
desc="add_user_pending_deactivation",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def del_user_pending_deactivation(self, user_id: str) -> None:
|
|
||||||
"""
|
|
||||||
Removes the given user to the table of users who need to be parted from all the
|
|
||||||
rooms they're in, effectively marking that user as fully deactivated.
|
|
||||||
"""
|
|
||||||
# XXX: This should be simple_delete_one but we failed to put a unique index on
|
|
||||||
# the table, so somehow duplicate entries have ended up in it.
|
|
||||||
await self.db_pool.simple_delete(
|
|
||||||
"users_pending_deactivation",
|
|
||||||
keyvalues={"user_id": user_id},
|
|
||||||
desc="del_user_pending_deactivation",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_user_pending_deactivation(self) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Gets one user from the table of users waiting to be parted from all the rooms
|
|
||||||
they're in.
|
|
||||||
"""
|
|
||||||
return await self.db_pool.simple_select_one_onecol(
|
|
||||||
"users_pending_deactivation",
|
|
||||||
keyvalues={},
|
|
||||||
retcol="user_id",
|
|
||||||
allow_none=True,
|
|
||||||
desc="get_users_pending_deactivation",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def validate_threepid_session(
|
async def validate_threepid_session(
|
||||||
self, session_id: str, client_secret: str, token: str, current_ts: int
|
self, session_id: str, client_secret: str, token: str, current_ts: int
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
|
|
|
@ -869,6 +869,89 @@ class RoomWorkerStore(SQLBaseStore):
|
||||||
"get_all_new_public_rooms", get_all_new_public_rooms
|
"get_all_new_public_rooms", get_all_new_public_rooms
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_rooms_for_retention_period_in_range(
|
||||||
|
self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
|
||||||
|
) -> Dict[str, dict]:
|
||||||
|
"""Retrieves all of the rooms within the given retention range.
|
||||||
|
|
||||||
|
Optionally includes the rooms which don't have a retention policy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_ms: Duration in milliseconds that define the lower limit of
|
||||||
|
the range to handle (exclusive). If None, doesn't set a lower limit.
|
||||||
|
max_ms: Duration in milliseconds that define the upper limit of
|
||||||
|
the range to handle (inclusive). If None, doesn't set an upper limit.
|
||||||
|
include_null: Whether to include rooms which retention policy is NULL
|
||||||
|
in the returned set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The rooms within this range, along with their retention
|
||||||
|
policy. The key is "room_id", and maps to a dict describing the retention
|
||||||
|
policy associated with this room ID. The keys for this nested dict are
|
||||||
|
"min_lifetime" (int|None), and "max_lifetime" (int|None).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_rooms_for_retention_period_in_range_txn(txn):
|
||||||
|
range_conditions = []
|
||||||
|
args = []
|
||||||
|
|
||||||
|
if min_ms is not None:
|
||||||
|
range_conditions.append("max_lifetime > ?")
|
||||||
|
args.append(min_ms)
|
||||||
|
|
||||||
|
if max_ms is not None:
|
||||||
|
range_conditions.append("max_lifetime <= ?")
|
||||||
|
args.append(max_ms)
|
||||||
|
|
||||||
|
# Do a first query which will retrieve the rooms that have a retention policy
|
||||||
|
# in their current state.
|
||||||
|
sql = """
|
||||||
|
SELECT room_id, min_lifetime, max_lifetime FROM room_retention
|
||||||
|
INNER JOIN current_state_events USING (event_id, room_id)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if len(range_conditions):
|
||||||
|
sql += " WHERE (" + " AND ".join(range_conditions) + ")"
|
||||||
|
|
||||||
|
if include_null:
|
||||||
|
sql += " OR max_lifetime IS NULL"
|
||||||
|
|
||||||
|
txn.execute(sql, args)
|
||||||
|
|
||||||
|
rows = self.db_pool.cursor_to_dict(txn)
|
||||||
|
rooms_dict = {}
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
rooms_dict[row["room_id"]] = {
|
||||||
|
"min_lifetime": row["min_lifetime"],
|
||||||
|
"max_lifetime": row["max_lifetime"],
|
||||||
|
}
|
||||||
|
|
||||||
|
if include_null:
|
||||||
|
# If required, do a second query that retrieves all of the rooms we know
|
||||||
|
# of so we can handle rooms with no retention policy.
|
||||||
|
sql = "SELECT DISTINCT room_id FROM current_state_events"
|
||||||
|
|
||||||
|
txn.execute(sql)
|
||||||
|
|
||||||
|
rows = self.db_pool.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
# If a room isn't already in the dict (i.e. it doesn't have a retention
|
||||||
|
# policy in its state), add it with a null policy.
|
||||||
|
for row in rows:
|
||||||
|
if row["room_id"] not in rooms_dict:
|
||||||
|
rooms_dict[row["room_id"]] = {
|
||||||
|
"min_lifetime": None,
|
||||||
|
"max_lifetime": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
return rooms_dict
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"get_rooms_for_retention_period_in_range",
|
||||||
|
get_rooms_for_retention_period_in_range_txn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RoomBackgroundUpdateStore(SQLBaseStore):
|
class RoomBackgroundUpdateStore(SQLBaseStore):
|
||||||
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
|
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
|
||||||
|
@ -1446,88 +1529,3 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||||
self.is_room_blocked,
|
self.is_room_blocked,
|
||||||
(room_id,),
|
(room_id,),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_rooms_for_retention_period_in_range(
|
|
||||||
self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
|
|
||||||
) -> Dict[str, dict]:
|
|
||||||
"""Retrieves all of the rooms within the given retention range.
|
|
||||||
|
|
||||||
Optionally includes the rooms which don't have a retention policy.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
min_ms: Duration in milliseconds that define the lower limit of
|
|
||||||
the range to handle (exclusive). If None, doesn't set a lower limit.
|
|
||||||
max_ms: Duration in milliseconds that define the upper limit of
|
|
||||||
the range to handle (inclusive). If None, doesn't set an upper limit.
|
|
||||||
include_null: Whether to include rooms which retention policy is NULL
|
|
||||||
in the returned set.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The rooms within this range, along with their retention
|
|
||||||
policy. The key is "room_id", and maps to a dict describing the retention
|
|
||||||
policy associated with this room ID. The keys for this nested dict are
|
|
||||||
"min_lifetime" (int|None), and "max_lifetime" (int|None).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_rooms_for_retention_period_in_range_txn(txn):
|
|
||||||
range_conditions = []
|
|
||||||
args = []
|
|
||||||
|
|
||||||
if min_ms is not None:
|
|
||||||
range_conditions.append("max_lifetime > ?")
|
|
||||||
args.append(min_ms)
|
|
||||||
|
|
||||||
if max_ms is not None:
|
|
||||||
range_conditions.append("max_lifetime <= ?")
|
|
||||||
args.append(max_ms)
|
|
||||||
|
|
||||||
# Do a first query which will retrieve the rooms that have a retention policy
|
|
||||||
# in their current state.
|
|
||||||
sql = """
|
|
||||||
SELECT room_id, min_lifetime, max_lifetime FROM room_retention
|
|
||||||
INNER JOIN current_state_events USING (event_id, room_id)
|
|
||||||
"""
|
|
||||||
|
|
||||||
if len(range_conditions):
|
|
||||||
sql += " WHERE (" + " AND ".join(range_conditions) + ")"
|
|
||||||
|
|
||||||
if include_null:
|
|
||||||
sql += " OR max_lifetime IS NULL"
|
|
||||||
|
|
||||||
txn.execute(sql, args)
|
|
||||||
|
|
||||||
rows = self.db_pool.cursor_to_dict(txn)
|
|
||||||
rooms_dict = {}
|
|
||||||
|
|
||||||
for row in rows:
|
|
||||||
rooms_dict[row["room_id"]] = {
|
|
||||||
"min_lifetime": row["min_lifetime"],
|
|
||||||
"max_lifetime": row["max_lifetime"],
|
|
||||||
}
|
|
||||||
|
|
||||||
if include_null:
|
|
||||||
# If required, do a second query that retrieves all of the rooms we know
|
|
||||||
# of so we can handle rooms with no retention policy.
|
|
||||||
sql = "SELECT DISTINCT room_id FROM current_state_events"
|
|
||||||
|
|
||||||
txn.execute(sql)
|
|
||||||
|
|
||||||
rows = self.db_pool.cursor_to_dict(txn)
|
|
||||||
|
|
||||||
# If a room isn't already in the dict (i.e. it doesn't have a retention
|
|
||||||
# policy in its state), add it with a null policy.
|
|
||||||
for row in rows:
|
|
||||||
if row["room_id"] not in rooms_dict:
|
|
||||||
rooms_dict[row["room_id"]] = {
|
|
||||||
"min_lifetime": None,
|
|
||||||
"max_lifetime": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
return rooms_dict
|
|
||||||
|
|
||||||
rooms = await self.db_pool.runInteraction(
|
|
||||||
"get_rooms_for_retention_period_in_range",
|
|
||||||
get_rooms_for_retention_period_in_range_txn,
|
|
||||||
)
|
|
||||||
|
|
||||||
return rooms
|
|
||||||
|
|
|
@ -65,26 +65,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
mock_federation_client = Mock(spec=["put_json"])
|
mock_federation_client = Mock(spec=["put_json"])
|
||||||
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
|
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
|
||||||
|
|
||||||
datastores = Mock()
|
|
||||||
datastores.main = Mock(
|
|
||||||
spec=[
|
|
||||||
# Bits that Federation needs
|
|
||||||
"prep_send_transaction",
|
|
||||||
"delivered_txn",
|
|
||||||
"get_received_txn_response",
|
|
||||||
"set_received_txn_response",
|
|
||||||
"get_destination_last_successful_stream_ordering",
|
|
||||||
"get_destination_retry_timings",
|
|
||||||
"get_devices_by_remote",
|
|
||||||
"maybe_store_room_on_invite",
|
|
||||||
# Bits that user_directory needs
|
|
||||||
"get_user_directory_stream_pos",
|
|
||||||
"get_current_state_deltas",
|
|
||||||
"get_device_updates_by_remote",
|
|
||||||
"get_room_max_stream_ordering",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# the tests assume that we are starting at unix time 1000
|
# the tests assume that we are starting at unix time 1000
|
||||||
reactor.pump((1000,))
|
reactor.pump((1000,))
|
||||||
|
|
||||||
|
@ -95,8 +75,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
replication_streams={},
|
replication_streams={},
|
||||||
)
|
)
|
||||||
|
|
||||||
hs.datastores = datastores
|
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
|
@ -114,16 +92,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
"retry_interval": 0,
|
"retry_interval": 0,
|
||||||
"failure_ts": None,
|
"failure_ts": None,
|
||||||
}
|
}
|
||||||
self.datastore.get_destination_retry_timings.return_value = defer.succeed(
|
self.datastore.get_destination_retry_timings = Mock(
|
||||||
retry_timings_res
|
return_value=defer.succeed(retry_timings_res)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.datastore.get_device_updates_by_remote.return_value = make_awaitable(
|
self.datastore.get_device_updates_by_remote = Mock(
|
||||||
(0, [])
|
return_value=make_awaitable((0, []))
|
||||||
)
|
)
|
||||||
|
|
||||||
self.datastore.get_destination_last_successful_stream_ordering.return_value = make_awaitable(
|
self.datastore.get_destination_last_successful_stream_ordering = Mock(
|
||||||
None
|
return_value=make_awaitable(None)
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_received_txn_response(*args):
|
def get_received_txn_response(*args):
|
||||||
|
@ -145,17 +123,19 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
|
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
|
||||||
|
|
||||||
def get_users_in_room(room_id):
|
async def get_users_in_room(room_id):
|
||||||
return defer.succeed({str(u) for u in self.room_members})
|
return {str(u) for u in self.room_members}
|
||||||
|
|
||||||
self.datastore.get_users_in_room = get_users_in_room
|
self.datastore.get_users_in_room = get_users_in_room
|
||||||
|
|
||||||
self.datastore.get_user_directory_stream_pos.side_effect = (
|
self.datastore.get_user_directory_stream_pos = Mock(
|
||||||
# we deliberately return a non-None stream pos to avoid doing an initial_spam
|
side_effect=(
|
||||||
lambda: make_awaitable(1)
|
# we deliberately return a non-None stream pos to avoid doing an initial_spam
|
||||||
|
lambda: make_awaitable(1)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.datastore.get_current_state_deltas.return_value = (0, None)
|
self.datastore.get_current_state_deltas = Mock(return_value=(0, None))
|
||||||
|
|
||||||
self.datastore.get_to_device_stream_token = lambda: 0
|
self.datastore.get_to_device_stream_token = lambda: 0
|
||||||
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
|
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
|
||||||
|
|
Loading…
Reference in a new issue