Move additional tasks to the background worker, part 4 (#8513)

Patrick Cloke, 2020-10-13 08:20:32 -04:00 (committed by GitHub)
parent b2486f6656
commit 629a951b49
11 changed files with 199 additions and 224 deletions

changelog.d/8513.feature (new file)

@@ -0,0 +1 @@
+Allow running background tasks in a separate worker process.
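The one-line changelog summarizes the series: each handler below now gates its periodic work on `hs.config.run_background_tasks`, which is true only in the process designated to run background tasks. A minimal, self-contained sketch of the pattern — `SimpleClock` and `ExampleHandler` are illustrative stand-ins, not Synapse classes:

```python
class SimpleClock:
    """Stand-in for Synapse's Clock: records repeating calls instead of running them."""

    def __init__(self):
        self.scheduled = []

    def looping_call(self, f, msec):
        self.scheduled.append((f.__name__, msec))


class ExampleHandler:
    def __init__(self, run_background_tasks: bool):
        self.clock = SimpleClock()
        # Periodic work is scheduled only on the one process designated to
        # run background tasks; every other worker skips it.
        if run_background_tasks:
            self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)

    def _send_renewal_emails(self):
        ...


assert ExampleHandler(True).clock.scheduled == [("_send_renewal_emails", 1800000)]
assert ExampleHandler(False).clock.scheduled == []
```

With this gate in place, a deployment can point all periodic jobs at a single dedicated worker instead of the main process.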

synapse/handlers/account_validity.py

@@ -70,7 +70,8 @@ class AccountValidityHandler:
                     "send_renewals", self._send_renewal_emails
                 )

-            self.clock.looping_call(send_emails, 30 * 60 * 1000)
+            if hs.config.run_background_tasks:
+                self.clock.looping_call(send_emails, 30 * 60 * 1000)

     async def _send_renewal_emails(self):
         """Gets the list of users whose account is expiring in the amount of time

synapse/handlers/deactivate_account.py

@@ -45,7 +45,7 @@ class DeactivateAccountHandler(BaseHandler):
         # Start the user parter loop so it can resume parting users from rooms where
         # it left off (if it has work left to do).
-        if hs.config.worker_app is None:
+        if hs.config.run_background_tasks:
             hs.get_reactor().callWhenRunning(self._start_user_parting)

         self._account_validity_enabled = hs.config.account_validity.enabled

synapse/handlers/message.py

@@ -402,21 +402,23 @@ class EventCreationHandler:
             self.config.block_events_without_consent_error
         )

-        # Rooms which should be excluded from dummy insertion. (For instance,
-        # those without local users who can send events into the room).
-        #
-        # map from room id to time-of-last-attempt.
-        #
-        self._rooms_to_exclude_from_dummy_event_insertion = {}  # type: Dict[str, int]
-
         # we need to construct a ConsentURIBuilder here, as it checks that the necessary
         # config options are set, but *only* if we have a configuration for which we are
         # going to need it.
         if self._block_events_without_consent_error:
             self._consent_uri_builder = ConsentURIBuilder(self.config)

+        # Rooms which should be excluded from dummy insertion. (For instance,
+        # those without local users who can send events into the room).
+        #
+        # map from room id to time-of-last-attempt.
+        #
+        self._rooms_to_exclude_from_dummy_event_insertion = {}  # type: Dict[str, int]
+        # The number of forward extremities before a dummy event is sent.
+        self._dummy_events_threshold = hs.config.dummy_events_threshold
+
         if (
-            not self.config.worker_app
+            self.config.run_background_tasks
             and self.config.cleanup_extremities_with_dummy_events
         ):
             self.clock.looping_call(
@@ -431,8 +433,6 @@ class EventCreationHandler:

         self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages

-        self._dummy_events_threshold = hs.config.dummy_events_threshold
-
     async def create_event(
         self,
         requester: Requester,
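The relocated comment describes `_rooms_to_exclude_from_dummy_event_insertion` as a map from room id to time-of-last-attempt. A hypothetical sketch of how such a map can gate retries; the helper name and the 24-hour window are assumptions, not Synapse's actual policy:

```python
import time

# Hypothetical retry window; Synapse's real expiry policy may differ.
ROOM_RETRY_AFTER_MS = 24 * 60 * 60 * 1000


def rooms_eligible_for_dummy_insert(exclude: dict, candidate_rooms: set, now_ms: int) -> set:
    """Drop exclusions older than the retry window, then filter candidates."""
    for room_id in [r for r, ts in exclude.items() if now_ms - ts > ROOM_RETRY_AFTER_MS]:
        del exclude[room_id]
    return candidate_rooms - exclude.keys()


now = int(time.time() * 1000)
exclude = {"!old:example.org": now - 2 * ROOM_RETRY_AFTER_MS, "!new:example.org": now}
print(rooms_eligible_for_dummy_insert(exclude, {"!old:example.org", "!new:example.org"}, now))
# {'!old:example.org'} — the expired exclusion is retried, the fresh one is skipped
```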

synapse/handlers/pagination.py

@@ -92,7 +92,7 @@ class PaginationHandler:
         self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
         self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max

-        if hs.config.retention_enabled:
+        if hs.config.run_background_tasks and hs.config.retention_enabled:
             # Run the purge jobs described in the configuration file.
             for job in hs.config.retention_purge_jobs:
                 logger.info("Setting up purge job with config: %s", job)

synapse/handlers/profile.py

@@ -35,14 +35,16 @@ MAX_DISPLAYNAME_LEN = 256
 MAX_AVATAR_URL_LEN = 1000


-class BaseProfileHandler(BaseHandler):
+class ProfileHandler(BaseHandler):
     """Handles fetching and updating user profile information.

-    BaseProfileHandler can be instantiated directly on workers and will
-    delegate to master when necessary. The master process should use the
-    subclass MasterProfileHandler
+    ProfileHandler can be instantiated directly on workers and will
+    delegate to master when necessary.
     """

+    PROFILE_UPDATE_MS = 60 * 1000
+    PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
+
     def __init__(self, hs):
         super().__init__(hs)
@@ -53,6 +55,11 @@ class BaseProfileHandler(BaseHandler):
         self.user_directory_handler = hs.get_user_directory_handler()

+        if hs.config.run_background_tasks:
+            self.clock.looping_call(
+                self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
+            )
+
     async def get_profile(self, user_id):
         target_user = UserID.from_string(user_id)
@@ -363,20 +370,6 @@ class BaseProfileHandler(BaseHandler):
                 raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
             raise

-
-class MasterProfileHandler(BaseProfileHandler):
-    PROFILE_UPDATE_MS = 60 * 1000
-    PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
-
-    def __init__(self, hs):
-        super().__init__(hs)
-
-        assert hs.config.worker_app is None
-
-        self.clock.looping_call(
-            self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
-        )
-
     def _start_update_remote_profile_cache(self):
         return run_as_background_process(
             "Update remote profile", self._update_remote_profile_cache

synapse/server.py

@@ -75,7 +75,7 @@ from synapse.handlers.message import EventCreationHandler, MessageHandler
 from synapse.handlers.pagination import PaginationHandler
 from synapse.handlers.password_policy import PasswordPolicyHandler
 from synapse.handlers.presence import PresenceHandler
-from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler
+from synapse.handlers.profile import ProfileHandler
 from synapse.handlers.read_marker import ReadMarkerHandler
 from synapse.handlers.receipts import ReceiptsHandler
 from synapse.handlers.register import RegistrationHandler
@@ -191,7 +191,12 @@ class HomeServer(metaclass=abc.ABCMeta):
     """

     REQUIRED_ON_BACKGROUND_TASK_STARTUP = [
         "account_validity",
+        "auth",
+        "deactivate_account",
+        "message",
+        "pagination",
+        "profile",
         "stats",
     ]
@@ -462,10 +467,7 @@ class HomeServer(metaclass=abc.ABCMeta):

     @cache_in_self
     def get_profile_handler(self):
-        if self.config.worker_app:
-            return BaseProfileHandler(self)
-        else:
-            return MasterProfileHandler(self)
+        return ProfileHandler(self)

     @cache_in_self
     def get_event_creation_handler(self) -> EventCreationHandler:
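`get_profile_handler` stays behind `cache_in_self`, so the handler is built once per process and reused. A rough sketch of that memoization idea (an approximation for illustration, not Synapse's exact decorator):

```python
import functools


def cache_in_self(builder):
    """Cache a zero-arg builder method's result on the instance (approximation)."""
    attr = "_cached_" + builder.__name__

    @functools.wraps(builder)
    def wrapped(self):
        if not hasattr(self, attr):
            setattr(self, attr, builder(self))
        return getattr(self, attr)

    return wrapped


class ExampleHomeServer:
    @cache_in_self
    def get_profile_handler(self):
        return object()  # stand-in for ProfileHandler(self)


hs = ExampleHomeServer()
assert hs.get_profile_handler() is hs.get_profile_handler()  # same instance each call
```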

synapse/storage/databases/main/profile.py

@@ -91,27 +91,6 @@ class ProfileWorkerStore(SQLBaseStore):
             desc="set_profile_avatar_url",
         )

-
-class ProfileStore(ProfileWorkerStore):
-    async def add_remote_profile_cache(
-        self, user_id: str, displayname: str, avatar_url: str
-    ) -> None:
-        """Ensure we are caching the remote user's profiles.
-
-        This should only be called when `is_subscribed_remote_profile_for_user`
-        would return true for the user.
-        """
-        await self.db_pool.simple_upsert(
-            table="remote_profile_cache",
-            keyvalues={"user_id": user_id},
-            values={
-                "displayname": displayname,
-                "avatar_url": avatar_url,
-                "last_check": self._clock.time_msec(),
-            },
-            desc="add_remote_profile_cache",
-        )
-
     async def update_remote_profile_cache(
         self, user_id: str, displayname: str, avatar_url: str
     ) -> int:
@@ -138,28 +117,6 @@ class ProfileStore(ProfileWorkerStore):
             desc="delete_remote_profile_cache",
         )

-    async def get_remote_profile_cache_entries_that_expire(
-        self, last_checked: int
-    ) -> Dict[str, str]:
-        """Get all users who haven't been checked since `last_checked`
-        """
-
-        def _get_remote_profile_cache_entries_that_expire_txn(txn):
-            sql = """
-            SELECT user_id, displayname, avatar_url
-            FROM remote_profile_cache
-            WHERE last_check < ?
-            """
-
-            txn.execute(sql, (last_checked,))
-
-            return self.db_pool.cursor_to_dict(txn)
-
-        return await self.db_pool.runInteraction(
-            "get_remote_profile_cache_entries_that_expire",
-            _get_remote_profile_cache_entries_that_expire_txn,
-        )
-
     async def is_subscribed_remote_profile_for_user(self, user_id):
         """Check whether we are interested in a remote user's profile.
         """
@@ -184,3 +141,46 @@ class ProfileStore(ProfileWorkerStore):
         if res:
             return True

+    async def get_remote_profile_cache_entries_that_expire(
+        self, last_checked: int
+    ) -> Dict[str, str]:
+        """Get all users who haven't been checked since `last_checked`
+        """
+
+        def _get_remote_profile_cache_entries_that_expire_txn(txn):
+            sql = """
+            SELECT user_id, displayname, avatar_url
+            FROM remote_profile_cache
+            WHERE last_check < ?
+            """
+
+            txn.execute(sql, (last_checked,))
+
+            return self.db_pool.cursor_to_dict(txn)
+
+        return await self.db_pool.runInteraction(
+            "get_remote_profile_cache_entries_that_expire",
+            _get_remote_profile_cache_entries_that_expire_txn,
+        )
+
+
+class ProfileStore(ProfileWorkerStore):
+    async def add_remote_profile_cache(
+        self, user_id: str, displayname: str, avatar_url: str
+    ) -> None:
+        """Ensure we are caching the remote user's profiles.
+
+        This should only be called when `is_subscribed_remote_profile_for_user`
+        would return true for the user.
+        """
+        await self.db_pool.simple_upsert(
+            table="remote_profile_cache",
+            keyvalues={"user_id": user_id},
+            values={
+                "displayname": displayname,
+                "avatar_url": avatar_url,
+                "last_check": self._clock.time_msec(),
+            },
+            desc="add_remote_profile_cache",
+        )
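With `get_remote_profile_cache_entries_that_expire` moved into `ProfileWorkerStore`, the background worker can scan the cache itself. The query is simple enough to exercise standalone; a sqlite3 illustration with an assumed minimal schema:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE remote_profile_cache "
    "(user_id TEXT, displayname TEXT, avatar_url TEXT, last_check INTEGER)"
)
conn.executemany(
    "INSERT INTO remote_profile_cache VALUES (?, ?, ?, ?)",
    [
        ("@stale:remote.example", "Stale", "mxc://remote.example/a", 500),
        ("@fresh:remote.example", "Fresh", "mxc://remote.example/b", 5000),
    ],
)
# The same shape of query: rows not checked since `last_checked`.
rows = conn.execute(
    "SELECT user_id, displayname, avatar_url FROM remote_profile_cache WHERE last_check < ?",
    (1000,),
).fetchall()
print(rows)  # [('@stale:remote.example', 'Stale', 'mxc://remote.example/a')]
```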

synapse/storage/databases/main/registration.py

@@ -862,6 +862,32 @@ class RegistrationWorkerStore(SQLBaseStore):
             values={"expiration_ts_ms": expiration_ts, "email_sent": False},
         )

+    async def get_user_pending_deactivation(self) -> Optional[str]:
+        """
+        Gets one user from the table of users waiting to be parted from all the rooms
+        they're in.
+        """
+        return await self.db_pool.simple_select_one_onecol(
+            "users_pending_deactivation",
+            keyvalues={},
+            retcol="user_id",
+            allow_none=True,
+            desc="get_users_pending_deactivation",
+        )
+
+    async def del_user_pending_deactivation(self, user_id: str) -> None:
+        """
+        Removes the given user from the table of users who need to be parted from all
+        the rooms they're in, effectively marking that user as fully deactivated.
+        """
+        # XXX: This should be simple_delete_one but we failed to put a unique index on
+        # the table, so somehow duplicate entries have ended up in it.
+        await self.db_pool.simple_delete(
+            "users_pending_deactivation",
+            keyvalues={"user_id": user_id},
+            desc="del_user_pending_deactivation",
+        )
+

 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -1371,32 +1397,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="add_user_pending_deactivation",
         )

-    async def del_user_pending_deactivation(self, user_id: str) -> None:
-        """
-        Removes the given user from the table of users who need to be parted from all
-        the rooms they're in, effectively marking that user as fully deactivated.
-        """
-        # XXX: This should be simple_delete_one but we failed to put a unique index on
-        # the table, so somehow duplicate entries have ended up in it.
-        await self.db_pool.simple_delete(
-            "users_pending_deactivation",
-            keyvalues={"user_id": user_id},
-            desc="del_user_pending_deactivation",
-        )
-
-    async def get_user_pending_deactivation(self) -> Optional[str]:
-        """
-        Gets one user from the table of users waiting to be parted from all the rooms
-        they're in.
-        """
-        return await self.db_pool.simple_select_one_onecol(
-            "users_pending_deactivation",
-            keyvalues={},
-            retcol="user_id",
-            allow_none=True,
-            desc="get_users_pending_deactivation",
-        )
-
     async def validate_threepid_session(
         self, session_id: str, client_secret: str, token: str, current_ts: int
     ) -> Optional[str]:
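The XXX comment explains why `simple_delete` is used instead of `simple_delete_one`: the table lacks a unique index, so duplicates can exist. A standalone sqlite3 illustration (assumed schema) of a keyed delete removing every duplicate at once:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# No UNIQUE constraint, mirroring the table described in the XXX comment.
conn.execute("CREATE TABLE users_pending_deactivation (user_id TEXT)")
conn.executemany(
    "INSERT INTO users_pending_deactivation VALUES (?)",
    [("@dupe:example.org",), ("@dupe:example.org",)],
)
# A delete keyed on user_id removes all matching rows in one go, which is why
# a delete-all-matching helper is the safe choice here.
conn.execute(
    "DELETE FROM users_pending_deactivation WHERE user_id = ?", ("@dupe:example.org",)
)
remaining = conn.execute("SELECT COUNT(*) FROM users_pending_deactivation").fetchone()[0]
assert remaining == 0
```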

synapse/storage/databases/main/room.py

@@ -869,6 +869,89 @@ class RoomWorkerStore(SQLBaseStore):
             "get_all_new_public_rooms", get_all_new_public_rooms
         )

+    async def get_rooms_for_retention_period_in_range(
+        self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
+    ) -> Dict[str, dict]:
+        """Retrieves all of the rooms within the given retention range.
+
+        Optionally includes the rooms which don't have a retention policy.
+
+        Args:
+            min_ms: Duration in milliseconds that defines the lower limit of
+                the range to handle (exclusive). If None, doesn't set a lower limit.
+            max_ms: Duration in milliseconds that defines the upper limit of
+                the range to handle (inclusive). If None, doesn't set an upper limit.
+            include_null: Whether to include rooms whose retention policy is NULL
+                in the returned set.
+
+        Returns:
+            The rooms within this range, along with their retention
+            policy. The key is "room_id", and maps to a dict describing the retention
+            policy associated with this room ID. The keys for this nested dict are
+            "min_lifetime" (int|None), and "max_lifetime" (int|None).
+        """
+
+        def get_rooms_for_retention_period_in_range_txn(txn):
+            range_conditions = []
+            args = []
+
+            if min_ms is not None:
+                range_conditions.append("max_lifetime > ?")
+                args.append(min_ms)
+
+            if max_ms is not None:
+                range_conditions.append("max_lifetime <= ?")
+                args.append(max_ms)
+
+            # Do a first query which will retrieve the rooms that have a retention policy
+            # in their current state.
+            sql = """
+                SELECT room_id, min_lifetime, max_lifetime FROM room_retention
+                INNER JOIN current_state_events USING (event_id, room_id)
+                """
+
+            if len(range_conditions):
+                sql += " WHERE (" + " AND ".join(range_conditions) + ")"
+
+                if include_null:
+                    sql += " OR max_lifetime IS NULL"
+
+            txn.execute(sql, args)
+
+            rows = self.db_pool.cursor_to_dict(txn)
+            rooms_dict = {}
+
+            for row in rows:
+                rooms_dict[row["room_id"]] = {
+                    "min_lifetime": row["min_lifetime"],
+                    "max_lifetime": row["max_lifetime"],
+                }
+
+            if include_null:
+                # If required, do a second query that retrieves all of the rooms we know
+                # of so we can handle rooms with no retention policy.
+                sql = "SELECT DISTINCT room_id FROM current_state_events"
+
+                txn.execute(sql)
+
+                rows = self.db_pool.cursor_to_dict(txn)
+
+                # If a room isn't already in the dict (i.e. it doesn't have a retention
+                # policy in its state), add it with a null policy.
+                for row in rows:
+                    if row["room_id"] not in rooms_dict:
+                        rooms_dict[row["room_id"]] = {
+                            "min_lifetime": None,
+                            "max_lifetime": None,
+                        }
+
+            return rooms_dict
+
+        return await self.db_pool.runInteraction(
+            "get_rooms_for_retention_period_in_range",
+            get_rooms_for_retention_period_in_range_txn,
+        )
+

 class RoomBackgroundUpdateStore(SQLBaseStore):
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1446,88 +1529,3 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             self.is_room_blocked,
             (room_id,),
         )
-
-    async def get_rooms_for_retention_period_in_range(
-        self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
-    ) -> Dict[str, dict]:
-        """Retrieves all of the rooms within the given retention range.
-
-        Optionally includes the rooms which don't have a retention policy.
-
-        Args:
-            min_ms: Duration in milliseconds that defines the lower limit of
-                the range to handle (exclusive). If None, doesn't set a lower limit.
-            max_ms: Duration in milliseconds that defines the upper limit of
-                the range to handle (inclusive). If None, doesn't set an upper limit.
-            include_null: Whether to include rooms whose retention policy is NULL
-                in the returned set.
-
-        Returns:
-            The rooms within this range, along with their retention
-            policy. The key is "room_id", and maps to a dict describing the retention
-            policy associated with this room ID. The keys for this nested dict are
-            "min_lifetime" (int|None), and "max_lifetime" (int|None).
-        """
-
-        def get_rooms_for_retention_period_in_range_txn(txn):
-            range_conditions = []
-            args = []
-
-            if min_ms is not None:
-                range_conditions.append("max_lifetime > ?")
-                args.append(min_ms)
-
-            if max_ms is not None:
-                range_conditions.append("max_lifetime <= ?")
-                args.append(max_ms)
-
-            # Do a first query which will retrieve the rooms that have a retention policy
-            # in their current state.
-            sql = """
-                SELECT room_id, min_lifetime, max_lifetime FROM room_retention
-                INNER JOIN current_state_events USING (event_id, room_id)
-                """
-
-            if len(range_conditions):
-                sql += " WHERE (" + " AND ".join(range_conditions) + ")"
-
-                if include_null:
-                    sql += " OR max_lifetime IS NULL"
-
-            txn.execute(sql, args)
-
-            rows = self.db_pool.cursor_to_dict(txn)
-            rooms_dict = {}
-
-            for row in rows:
-                rooms_dict[row["room_id"]] = {
-                    "min_lifetime": row["min_lifetime"],
-                    "max_lifetime": row["max_lifetime"],
-                }
-
-            if include_null:
-                # If required, do a second query that retrieves all of the rooms we know
-                # of so we can handle rooms with no retention policy.
-                sql = "SELECT DISTINCT room_id FROM current_state_events"
-
-                txn.execute(sql)
-
-                rows = self.db_pool.cursor_to_dict(txn)
-
-                # If a room isn't already in the dict (i.e. it doesn't have a retention
-                # policy in its state), add it with a null policy.
-                for row in rows:
-                    if row["room_id"] not in rooms_dict:
-                        rooms_dict[row["room_id"]] = {
-                            "min_lifetime": None,
-                            "max_lifetime": None,
-                        }
-
-            return rooms_dict
-
-        rooms = await self.db_pool.runInteraction(
-            "get_rooms_for_retention_period_in_range",
-            get_rooms_for_retention_period_in_range_txn,
-        )
-
-        return rooms
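The moved method builds its SQL incrementally, with the `include_null` clause appended textually after the parenthesized range conditions. A standalone sketch mirroring the assembly logic from the diff above (an illustrative helper, not importable Synapse code), so the resulting WHERE/OR text can be inspected directly:

```python
from typing import Optional


def build_retention_sql(min_ms: Optional[int], max_ms: Optional[int], include_null: bool):
    sql = (
        "SELECT room_id, min_lifetime, max_lifetime FROM room_retention "
        "INNER JOIN current_state_events USING (event_id, room_id)"
    )
    conditions, args = [], []
    if min_ms is not None:
        conditions.append("max_lifetime > ?")  # exclusive lower bound
        args.append(min_ms)
    if max_ms is not None:
        conditions.append("max_lifetime <= ?")  # inclusive upper bound
        args.append(max_ms)
    if conditions:
        sql += " WHERE (" + " AND ".join(conditions) + ")"
        if include_null:
            # Appended textually, so it ORs against the parenthesized range.
            sql += " OR max_lifetime IS NULL"
    return sql, args


print(build_retention_sql(86400000, 604800000, include_null=True))
```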

tests/handlers/test_typing.py

@@ -65,26 +65,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         mock_federation_client = Mock(spec=["put_json"])
         mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))

-        datastores = Mock()
-        datastores.main = Mock(
-            spec=[
-                # Bits that Federation needs
-                "prep_send_transaction",
-                "delivered_txn",
-                "get_received_txn_response",
-                "set_received_txn_response",
-                "get_destination_last_successful_stream_ordering",
-                "get_destination_retry_timings",
-                "get_devices_by_remote",
-                "maybe_store_room_on_invite",
-                # Bits that user_directory needs
-                "get_user_directory_stream_pos",
-                "get_current_state_deltas",
-                "get_device_updates_by_remote",
-                "get_room_max_stream_ordering",
-            ]
-        )
-
         # the tests assume that we are starting at unix time 1000
         reactor.pump((1000,))
@@ -95,8 +75,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             replication_streams={},
         )

-        hs.datastores = datastores
-
         return hs

     def prepare(self, reactor, clock, hs):
@@ -114,16 +92,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             "retry_interval": 0,
             "failure_ts": None,
         }
-        self.datastore.get_destination_retry_timings.return_value = defer.succeed(
-            retry_timings_res
+        self.datastore.get_destination_retry_timings = Mock(
+            return_value=defer.succeed(retry_timings_res)
         )

-        self.datastore.get_device_updates_by_remote.return_value = make_awaitable(
-            (0, [])
+        self.datastore.get_device_updates_by_remote = Mock(
+            return_value=make_awaitable((0, []))
         )

-        self.datastore.get_destination_last_successful_stream_ordering.return_value = make_awaitable(
-            None
+        self.datastore.get_destination_last_successful_stream_ordering = Mock(
+            return_value=make_awaitable(None)
        )

         def get_received_txn_response(*args):
@@ -145,17 +123,19 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room

-        def get_users_in_room(room_id):
-            return defer.succeed({str(u) for u in self.room_members})
+        async def get_users_in_room(room_id):
+            return {str(u) for u in self.room_members}

         self.datastore.get_users_in_room = get_users_in_room

-        self.datastore.get_user_directory_stream_pos.side_effect = (
-            # we deliberately return a non-None stream pos to avoid doing an initial_spam
-            lambda: make_awaitable(1)
+        self.datastore.get_user_directory_stream_pos = Mock(
+            side_effect=(
+                # we deliberately return a non-None stream pos to avoid doing an initial_spam
+                lambda: make_awaitable(1)
+            )
         )

-        self.datastore.get_current_state_deltas.return_value = (0, None)
+        self.datastore.get_current_state_deltas = Mock(return_value=(0, None))

         self.datastore.get_to_device_stream_token = lambda: 0
         self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
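These hunks replace in-place `return_value` assignments with fresh `Mock(...)` objects whose return values are awaitables built by Synapse's `make_awaitable` test helper. On Python 3.8+, `unittest.mock.AsyncMock` achieves a similar effect and hands out a new awaitable per call; a standalone sketch of that idiom (not the code these tests actually use):

```python
import asyncio
from unittest.mock import AsyncMock

datastore = AsyncMock()
# Each call returns an awaitable resolving to the configured value, much like
# Mock(return_value=make_awaitable((0, []))) in the diff above.
datastore.get_device_updates_by_remote.return_value = (0, [])


async def main():
    print(await datastore.get_device_updates_by_remote("remote.example"))  # (0, [])


asyncio.run(main())
```

Because `AsyncMock` builds a fresh coroutine on every call, it sidesteps the one-shot-await problem that makes a plain coroutine unsuitable as a shared `return_value`.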