diff --git a/changelog.d/16288.bugfix b/changelog.d/16288.bugfix new file mode 100644 index 0000000000..f08d10d1f3 --- /dev/null +++ b/changelog.d/16288.bugfix @@ -0,0 +1 @@ +Fix bug introduced in Synapse 1.49.0 when using dehydrated devices ([MSC2697](https://github.com/matrix-org/matrix-spec-proposals/pull/2697)) and refresh tokens. Contributed by Hanadi. diff --git a/changelog.d/16301.misc b/changelog.d/16301.misc new file mode 100644 index 0000000000..93ceaeafc9 --- /dev/null +++ b/changelog.d/16301.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/changelog.d/16304.doc b/changelog.d/16304.doc new file mode 100644 index 0000000000..53660ec9a4 --- /dev/null +++ b/changelog.d/16304.doc @@ -0,0 +1 @@ +Link to the Alpine Linux community package for Synapse. diff --git a/changelog.d/16313.misc b/changelog.d/16313.misc new file mode 100644 index 0000000000..4f266c1fb0 --- /dev/null +++ b/changelog.d/16313.misc @@ -0,0 +1 @@ +Delete device messages asynchronously and in staged batches using the task scheduler. diff --git a/changelog.d/16314.misc b/changelog.d/16314.misc new file mode 100644 index 0000000000..a32b07112a --- /dev/null +++ b/changelog.d/16314.misc @@ -0,0 +1 @@ +Remove a reference cycle in background processes. diff --git a/changelog.d/16316.misc b/changelog.d/16316.misc new file mode 100644 index 0000000000..aa0644f278 --- /dev/null +++ b/changelog.d/16316.misc @@ -0,0 +1 @@ +Refactor `get_user_by_id`. diff --git a/changelog.d/16318.misc b/changelog.d/16318.misc new file mode 100644 index 0000000000..1433a2f246 --- /dev/null +++ b/changelog.d/16318.misc @@ -0,0 +1 @@ +Speed up task to delete to-device messages. 
diff --git a/docs/setup/installation.md b/docs/setup/installation.md index 0357d2a0fb..1f13864a8f 100644 --- a/docs/setup/installation.md +++ b/docs/setup/installation.md @@ -155,6 +155,14 @@ sudo pip uninstall py-bcrypt sudo pip install py-bcrypt ``` +#### Alpine Linux + +6543 maintains [Synapse packages for Alpine Linux](https://pkgs.alpinelinux.org/packages?name=synapse&branch=edge) in the community repository. Install with: + +```sh +sudo apk add synapse +``` + #### Void Linux Synapse can be found in the void repositories as diff --git a/synapse/api/auth/internal.py b/synapse/api/auth/internal.py index 6a5fd44ec0..a75f6f2cc4 100644 --- a/synapse/api/auth/internal.py +++ b/synapse/api/auth/internal.py @@ -268,7 +268,7 @@ class InternalAuth(BaseAuth): stored_user = await self.store.get_user_by_id(user_id) if not stored_user: raise InvalidClientTokenError("Unknown user_id %s" % user_id) - if not stored_user["is_guest"]: + if not stored_user.is_guest: raise InvalidClientTokenError( "Guest access token used for regular user" ) diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py index ef5d3f9b81..31bb035cc8 100644 --- a/synapse/api/auth/msc3861_delegated.py +++ b/synapse/api/auth/msc3861_delegated.py @@ -300,7 +300,7 @@ class MSC3861DelegatedAuth(BaseAuth): user_id = UserID(username, self._hostname) # First try to find a user from the username claim - user_info = await self.store.get_userinfo_by_id(user_id=user_id.to_string()) + user_info = await self.store.get_user_by_id(user_id=user_id.to_string()) if user_info is None: # If the user does not exist, we should create it on the fly # TODO: we could use SCIM to provision users ahead of time and listen diff --git a/synapse/app/_base.py b/synapse/app/_base.py index a94b57a671..9ac7e4313e 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -27,9 +27,7 @@ from typing import ( Any, Awaitable, Callable, - Collection, Dict, - Iterable, List, NoReturn, Optional, @@ -76,7 
+74,7 @@ from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_ from synapse.module_api.callbacks.third_party_event_rules_callbacks import ( load_legacy_third_party_event_rules, ) -from synapse.types import ISynapseReactor +from synapse.types import ISynapseReactor, StrCollection from synapse.util import SYNAPSE_VERSION from synapse.util.caches.lrucache import setup_expire_lru_cache_entries from synapse.util.daemonize import daemonize_process @@ -278,7 +276,7 @@ def register_start( reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper())) -def listen_metrics(bind_addresses: Iterable[str], port: int) -> None: +def listen_metrics(bind_addresses: StrCollection, port: int) -> None: """ Start Prometheus metrics server. """ @@ -315,7 +313,7 @@ def _set_prometheus_client_use_created_metrics(new_value: bool) -> None: def listen_manhole( - bind_addresses: Collection[str], + bind_addresses: StrCollection, port: int, manhole_settings: ManholeConfig, manhole_globals: dict, @@ -339,7 +337,7 @@ def listen_manhole( def listen_tcp( - bind_addresses: Collection[str], + bind_addresses: StrCollection, port: int, factory: ServerFactory, reactor: IReactorTCP = reactor, @@ -448,7 +446,7 @@ def listen_http( def listen_ssl( - bind_addresses: Collection[str], + bind_addresses: StrCollection, port: int, factory: ServerFactory, context_factory: IOpenSSLContextFactory, diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 58856839e1..c5816105f4 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -26,7 +26,6 @@ from textwrap import dedent from typing import ( Any, ClassVar, - Collection, Dict, Iterable, Iterator, @@ -384,7 +383,7 @@ class RootConfig: config_classes: List[Type[Config]] = [] - def __init__(self, config_files: Collection[str] = ()): + def __init__(self, config_files: StrSequence = ()): # Capture absolute paths here, so we can reload config after we daemonize. 
self.config_files = [os.path.abspath(path) for path in config_files] diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 35257a3b1b..3c1777b7ec 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -25,7 +25,6 @@ from typing import ( Iterable, List, Optional, - Sequence, Tuple, Type, TypeVar, @@ -408,7 +407,7 @@ class EventBase(metaclass=abc.ABCMeta): def keys(self) -> Iterable[str]: return self._dict.keys() - def prev_event_ids(self) -> Sequence[str]: + def prev_event_ids(self) -> List[str]: """Returns the list of prev event IDs. The order matches the order specified in the event, though there is no meaning to it. @@ -553,7 +552,7 @@ class FrozenEventV2(EventBase): self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1]) return self._event_id - def prev_event_ids(self) -> Sequence[str]: + def prev_event_ids(self) -> List[str]: """Returns the list of prev event IDs. The order matches the order specified in the event, though there is no meaning to it. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 14ea0e6640..1165c017ba 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import attr from signedjson.types import SigningKey @@ -28,7 +28,7 @@ from synapse.event_auth import auth_types_for_event from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict from synapse.state import StateHandler from synapse.storage.databases.main import DataStore -from synapse.types import EventID, JsonDict +from synapse.types import EventID, JsonDict, StrCollection from synapse.types.state import StateFilter from synapse.util import Clock from synapse.util.stringutils import random_string @@ -103,7 +103,7 @@ class EventBuilder: async def build( self, - prev_event_ids: Collection[str], + prev_event_ids: StrCollection, auth_event_ids: Optional[List[str]], depth: Optional[int] = None, ) -> EventBase: @@ -136,7 +136,7 @@ class EventBuilder: format_version = self.room_version.event_format # The types of auth/prev events changes between event versions. - prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]] + prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]] auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] if format_version == EventFormatVersions.ROOM_V1_V2: auth_events = await self._store.add_event_hashes(auth_event_ids) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 34625dd7a1..5da50cb0d2 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import collections.abc -from typing import Iterable, List, Type, Union, cast +from typing import List, Type, Union, cast import jsonschema from pydantic import Field, StrictBool, StrictStr @@ -36,7 +36,7 @@ from synapse.events.utils import ( from synapse.federation.federation_server import server_matches_acl_event from synapse.http.servlet import validate_json_object from synapse.rest.models import RequestBodyModel -from synapse.types import EventID, JsonDict, RoomID, UserID +from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID class EventValidator: @@ -225,7 +225,7 @@ class EventValidator: self._ensure_state_event(event) - def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None: + def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None: for s in keys: if s not in d: raise SynapseError(400, "'%s' not in content" % (s,)) diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py index c05a14304c..fa043cca86 100644 --- a/synapse/handlers/account.py +++ b/synapse/handlers/account.py @@ -102,7 +102,7 @@ class AccountHandler: """ status = {"exists": False} - userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string()) + userinfo = await self._main_store.get_user_by_id(user_id.to_string()) if userinfo is not None: status = { diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 2f0e5f3b0a..7092ff3449 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set from synapse.api.constants import Direction, Membership from synapse.events import EventBase -from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID +from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -57,38 +57,30 @@ class AdminHandler: async def get_user(self, user: UserID) -> Optional[JsonDict]: 
"""Function to get user details""" - user_info_dict = await self._store.get_user_by_id(user.to_string()) - if user_info_dict is None: + user_info: Optional[UserInfo] = await self._store.get_user_by_id( + user.to_string() + ) + if user_info is None: return None - # Restrict returned information to a known set of fields. This prevents additional - # fields added to get_user_by_id from modifying Synapse's external API surface. - user_info_to_return = { - "name", - "admin", - "deactivated", - "locked", - "shadow_banned", - "creation_ts", - "appservice_id", - "consent_server_notice_sent", - "consent_version", - "consent_ts", - "user_type", - "is_guest", - "last_seen_ts", + user_info_dict = { + "name": user.to_string(), + "admin": user_info.is_admin, + "deactivated": user_info.is_deactivated, + "locked": user_info.locked, + "shadow_banned": user_info.is_shadow_banned, + "creation_ts": user_info.creation_ts, + "appservice_id": user_info.appservice_id, + "consent_server_notice_sent": user_info.consent_server_notice_sent, + "consent_version": user_info.consent_version, + "consent_ts": user_info.consent_ts, + "user_type": user_info.user_type, + "is_guest": user_info.is_guest, } if self._msc3866_enabled: # Only include the approved flag if support for MSC3866 is enabled. - user_info_to_return.add("approved") - - # Restrict returned keys to a known set. 
- user_info_dict = { - key: value - for key, value in user_info_dict.items() - if key in user_info_to_return - } + user_info_dict["approved"] = user_info.approved # Add additional user metadata profile = await self._store.get_profileinfo(user) @@ -105,6 +97,9 @@ class AdminHandler: user_info_dict["external_ids"] = external_ids user_info_dict["erased"] = await self._store.is_user_erased(user.to_string()) + last_seen_ts = await self._store.get_last_seen_for_user_id(user.to_string()) + user_info_dict["last_seen_ts"] = last_seen_ts + return user_info_dict async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any: diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index e2ae3da67e..86ad96d030 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -388,7 +388,8 @@ class DeviceWorkerHandler: "Trying handling device list state for partial join: not supported on workers." ) - DEVICE_MSGS_DELETE_BATCH_LIMIT = 100 + DEVICE_MSGS_DELETE_BATCH_LIMIT = 1000 + DEVICE_MSGS_DELETE_SLEEP_MS = 1000 async def _delete_device_messages( self, @@ -400,19 +401,19 @@ class DeviceWorkerHandler: device_id = task.params["device_id"] up_to_stream_id = task.params["up_to_stream_id"] - res = await self.store.delete_messages_for_device( - user_id=user_id, - device_id=device_id, - up_to_stream_id=up_to_stream_id, - limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT, - ) + # Delete the messages in batches to avoid too much DB load. 
+ while True: + res = await self.store.delete_messages_for_device( + user_id=user_id, + device_id=device_id, + up_to_stream_id=up_to_stream_id, + limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT, + ) - if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT: - return TaskStatus.COMPLETE, None, None - else: - # There is probably still device messages to be deleted, let's keep the task active and it will be run - # again in a subsequent scheduler loop run (probably the next one, if not too many tasks are running). - return TaskStatus.ACTIVE, None, None + if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT: + return TaskStatus.COMPLETE, None, None + + await self.clock.sleep(DeviceHandler.DEVICE_MSGS_DELETE_SLEEP_MS / 1000.0) class DeviceHandler(DeviceWorkerHandler): @@ -758,12 +759,13 @@ class DeviceHandler(DeviceWorkerHandler): # If the dehydrated device was successfully deleted (the device ID # matched the stored dehydrated device), then modify the access - # token to use the dehydrated device's ID and copy the old device - # display name to the dehydrated device, and destroy the old device - # ID + # token and refresh token to use the dehydrated device's ID and + # copy the old device display name to the dehydrated device, + # and destroy the old device ID old_device_id = await self.store.set_device_for_access_token( access_token, device_id ) + await self.store.set_device_for_refresh_token(user_id, old_device_id, device_id) old_device = await self.store.get_device(user_id, old_device_id) if old_device is None: raise errors.NotFoundError() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d6be18cdef..c036578a3d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -828,13 +828,13 @@ class EventCreationHandler: u = await self.store.get_user_by_id(user_id) assert u is not None - if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT): + if u.user_type in (UserTypes.SUPPORT, UserTypes.BOT): # support and bot users 
are not required to consent return - if u["appservice_id"] is not None: + if u.appservice_id is not None: # users registered by an appservice are exempt return - if u["consent_version"] == self.config.consent.user_consent_version: + if u.consent_version == self.config.consent.user_consent_version: return consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart) diff --git a/synapse/http/client.py b/synapse/http/client.py index ca2cdbc6e2..c750e03b36 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -78,7 +78,7 @@ from synapse.http.replicationagent import ReplicationAgent from synapse.http.types import QueryParams from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag, start_active_span, tags -from synapse.types import ISynapseReactor +from synapse.types import ISynapseReactor, StrSequence from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred @@ -108,10 +108,9 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValu # the value actually has to be a List, but List is invariant so we can't specify that # the entries can either be Lists or bytes. 
RawHeaderValue = Union[ - List[str], + StrSequence, List[bytes], List[Union[str, bytes]], - Tuple[str, ...], Tuple[bytes, ...], Tuple[Union[str, bytes], ...], ] diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index fc62793628..5d79d31579 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -18,7 +18,6 @@ import logging from http import HTTPStatus from typing import ( TYPE_CHECKING, - Iterable, List, Mapping, Optional, @@ -38,7 +37,7 @@ from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError from synapse.http import redact_uri from synapse.http.server import HttpServer -from synapse.types import JsonDict, RoomAlias, RoomID +from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection from synapse.util import json_decoder if TYPE_CHECKING: @@ -340,7 +339,7 @@ def parse_string( name: str, default: str, *, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> str: ... @@ -352,7 +351,7 @@ def parse_string( name: str, *, required: Literal[True], - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> str: ... @@ -365,7 +364,7 @@ def parse_string( *, default: Optional[str] = None, required: bool = False, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> Optional[str]: ... 
@@ -376,7 +375,7 @@ def parse_string( name: str, default: Optional[str] = None, required: bool = False, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> Optional[str]: """ @@ -485,7 +484,7 @@ def parse_enum( def _parse_string_value( value: bytes, - allowed_values: Optional[Iterable[str]], + allowed_values: Optional[StrCollection], name: str, encoding: str, ) -> str: @@ -511,7 +510,7 @@ def parse_strings_from_args( args: Mapping[bytes, Sequence[bytes]], name: str, *, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> Optional[List[str]]: ... @@ -523,7 +522,7 @@ def parse_strings_from_args( name: str, default: List[str], *, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> List[str]: ... @@ -535,7 +534,7 @@ def parse_strings_from_args( name: str, *, required: Literal[True], - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> List[str]: ... @@ -548,7 +547,7 @@ def parse_strings_from_args( default: Optional[List[str]] = None, *, required: bool = False, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> Optional[List[str]]: ... @@ -559,7 +558,7 @@ def parse_strings_from_args( name: str, default: Optional[List[str]] = None, required: bool = False, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> Optional[List[str]]: """ @@ -610,7 +609,7 @@ def parse_string_from_args( name: str, default: Optional[str] = None, *, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> Optional[str]: ... 
@@ -623,7 +622,7 @@ def parse_string_from_args( default: Optional[str] = None, *, required: Literal[True], - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> str: ... @@ -635,7 +634,7 @@ def parse_string_from_args( name: str, default: Optional[str] = None, required: bool = False, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> Optional[str]: ... @@ -646,7 +645,7 @@ def parse_string_from_args( name: str, default: Optional[str] = None, required: bool = False, - allowed_values: Optional[Iterable[str]] = None, + allowed_values: Optional[StrCollection] = None, encoding: str = "ascii", ) -> Optional[str]: """ @@ -821,7 +820,7 @@ def parse_and_validate_json_object_from_request( return validate_json_object(content, model_type) -def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None: +def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None: absent = [] for k in required: if k not in body: diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 39fc629937..3cf2fbc3e2 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -25,7 +25,6 @@ from typing import ( Iterable, Mapping, Optional, - Sequence, Set, Tuple, Type, @@ -49,6 +48,7 @@ import synapse.metrics._reactor_metrics # noqa: F401 from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager from synapse.metrics._twisted_exposition import MetricsResource, generate_latest from synapse.metrics._types import Collector +from synapse.types import StrSequence from synapse.util import SYNAPSE_VERSION logger = logging.getLogger(__name__) @@ -81,7 +81,7 @@ class LaterGauge(Collector): name: str desc: str - labels: Optional[Sequence[str]] = attr.ib(hash=False) + labels: Optional[StrSequence] = attr.ib(hash=False) # callback: should either return a value (if there are no labels for 
this metric), # or dict mapping from a label tuple to a value caller: Callable[ @@ -143,8 +143,8 @@ class InFlightGauge(Generic[MetricsEntry], Collector): self, name: str, desc: str, - labels: Sequence[str], - sub_metrics: Sequence[str], + labels: StrSequence, + sub_metrics: StrSequence, ): self.name = name self.desc = desc diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 9ea4e23b31..f1f1f0cdf9 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -322,13 +322,21 @@ class BackgroundProcessLoggingContext(LoggingContext): if instance_id is None: instance_id = id(self) super().__init__("%s-%s" % (name, instance_id)) - self._proc = _BackgroundProcess(name, self) + self._proc: Optional[_BackgroundProcess] = _BackgroundProcess(name, self) def start(self, rusage: "Optional[resource.struct_rusage]") -> None: """Log context has started running (again).""" super().start(rusage) + if self._proc is None: + logger.error( + "Background process re-entered without a proc: %s", + self.name, + stack_info=True, + ) + return + # We've become active again so we make sure we're in the list of active # procs. (Note that "start" here means we've become active, as opposed # to starting for the first time.) @@ -345,6 +353,14 @@ class BackgroundProcessLoggingContext(LoggingContext): super().__exit__(type, value, traceback) + if self._proc is None: + logger.error( + "Background process exited without a proc: %s", + self.name, + stack_info=True, + ) + return + # The background process has finished. We explicitly remove and manually # update the metrics here so that if nothing is scraping metrics the set # doesn't infinitely grow. @@ -352,3 +368,6 @@ class BackgroundProcessLoggingContext(LoggingContext): _background_processes_active_since_last_scrape.discard(self._proc) self._proc.update_metrics() + + # Set proc to None to break the reference cycle. 
+ self._proc = None diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index d6efe10a28..7ec202be23 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -572,7 +572,7 @@ class ModuleApi: Returns: UserInfo object if a user was found, otherwise None """ - return await self._store.get_userinfo_by_id(user_id) + return await self._store.get_user_by_id(user_id) async def get_user_by_req( self, @@ -1878,7 +1878,7 @@ class AccountDataManager: raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}") # Ensure the user exists, so we don't just write to users that aren't there. - if await self._store.get_userinfo_by_id(user_id) is None: + if await self._store.get_user_by_id(user_id) is None: raise ValueError(f"User {user_id} does not exist on this server.") await self._handler.add_account_data_for_user(user_id, data_type, new_data) diff --git a/synapse/notifier.py b/synapse/notifier.py index 68115bca70..fc39e5c963 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -104,7 +104,7 @@ class _NotifierUserStream: def __init__( self, user_id: str, - rooms: Collection[str], + rooms: StrCollection, current_token: StreamToken, time_now_ms: int, ): @@ -457,7 +457,7 @@ class Notifier: stream_key: str, new_token: Union[int, RoomStreamToken], users: Optional[Collection[Union[str, UserID]]] = None, - rooms: Optional[Collection[str]] = None, + rooms: Optional[StrCollection] = None, ) -> None: """Used to inform listeners that something has happened event wise. 
@@ -529,7 +529,7 @@ class Notifier: user_id: str, timeout: int, callback: Callable[[StreamToken, StreamToken], Awaitable[T]], - room_ids: Optional[Collection[str]] = None, + room_ids: Optional[StrCollection] = None, from_token: StreamToken = StreamToken.START, ) -> T: """Wait until the callback returns a non empty response or the diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 5642666411..b668bb5da1 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -672,14 +672,12 @@ class ReplicationCommandHandler: cmd.instance_name, cmd.lock_name, cmd.lock_key ) - async def on_NEW_ACTIVE_TASK( + def on_NEW_ACTIVE_TASK( self, conn: IReplicationConnection, cmd: NewActiveTaskCommand ) -> None: """Called when get a new NEW_ACTIVE_TASK command.""" if self._task_scheduler: - task = await self._task_scheduler.get_task(cmd.data) - if task: - await self._task_scheduler._launch_task(task) + self._task_scheduler.launch_task_by_id(cmd.data) def new_connection(self, connection: IReplicationConnection) -> None: """Called when we have a new connection.""" diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py index 5c1c19e1f3..73c568ef75 100644 --- a/synapse/rest/client/_base.py +++ b/synapse/rest/client/_base.py @@ -20,14 +20,14 @@ from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, from synapse.api.errors import InteractiveAuthIncompleteError from synapse.api.urls import CLIENT_API_PREFIX -from synapse.types import JsonDict +from synapse.types import JsonDict, StrCollection logger = logging.getLogger(__name__) def client_patterns( path_regex: str, - releases: Iterable[str] = ("r0", "v3"), + releases: StrCollection = ("r0", "v3"), unstable: bool = True, v1: bool = False, ) -> Iterable[Pattern]: diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 25f9ea285b..88d3ec1baf 100644 --- 
a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -129,7 +129,7 @@ class ConsentResource(DirectServeHtmlResource): if u is None: raise NotFoundError("Unknown user") - has_consented = u["consent_version"] == version + has_consented = u.consent_version == version userhmac = userhmac_bytes.decode("ascii") try: diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 94025ba41f..a879b6505e 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -79,15 +79,15 @@ class ConsentServerNotices: if u is None: return - if u["is_guest"] and not self._send_to_guests: + if u.is_guest and not self._send_to_guests: # don't send to guests return - if u["consent_version"] == self._current_consent_version: + if u.consent_version == self._current_consent_version: # user has already consented return - if u["consent_server_notice_sent"] == self._current_consent_version: + if u.consent_server_notice_sent == self._current_consent_version: # we've already sent a notice to the user return diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 1b91cf5eaa..e977ed1044 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -20,7 +20,6 @@ from typing import ( Any, Awaitable, Callable, - Collection, DefaultDict, Dict, FrozenSet, @@ -49,7 +48,7 @@ from synapse.logging.opentracing import tag_args, trace from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.types import StateMap +from synapse.types import StateMap, StrCollection from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -197,7 +196,7 @@ class StateHandler: async def 
compute_state_after_events( self, room_id: str, - event_ids: Collection[str], + event_ids: StrCollection, state_filter: Optional[StateFilter] = None, await_full_state: bool = True, ) -> StateMap[str]: @@ -231,7 +230,7 @@ class StateHandler: return await ret.get_state(self._state_storage_controller, state_filter) async def get_current_user_ids_in_room( - self, room_id: str, latest_event_ids: Collection[str] + self, room_id: str, latest_event_ids: StrCollection ) -> Set[str]: """ Get the users IDs who are currently in a room. @@ -256,7 +255,7 @@ class StateHandler: return await self.store.get_joined_user_ids_from_state(room_id, state) async def get_hosts_in_room_at_events( - self, room_id: str, event_ids: Collection[str] + self, room_id: str, event_ids: StrCollection ) -> FrozenSet[str]: """Get the hosts that were in a room at the given event ids @@ -470,7 +469,7 @@ class StateHandler: @trace @measure_func() async def resolve_state_groups_for_events( - self, room_id: str, event_ids: Collection[str], await_full_state: bool = True + self, room_id: str, event_ids: StrCollection, await_full_state: bool = True ) -> _StateCacheEntry: """Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. 
@@ -882,7 +881,7 @@ class StateResolutionStore: store: "DataStore" def get_events( - self, event_ids: Collection[str], allow_rejected: bool = False + self, event_ids: StrCollection, allow_rejected: bool = False ) -> Awaitable[Dict[str, EventBase]]: """Get events from the database diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 500e384695..c76a2f082e 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -17,7 +17,6 @@ import logging from typing import ( Awaitable, Callable, - Collection, Dict, Iterable, List, @@ -32,7 +31,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersion from synapse.events import EventBase -from synapse.types import MutableStateMap, StateMap +from synapse.types import MutableStateMap, StateMap, StrCollection logger = logging.getLogger(__name__) @@ -45,7 +44,7 @@ async def resolve_events_with_store( room_version: RoomVersion, state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], - state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]], + state_map_factory: Callable[[StrCollection], Awaitable[Dict[str, EventBase]]], ) -> StateMap[str]: """ Args: diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 44c49274a9..1752f95db8 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -19,7 +19,6 @@ from typing import ( Any, Awaitable, Callable, - Collection, Dict, Generator, Iterable, @@ -39,7 +38,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersion from synapse.events import EventBase -from synapse.types import MutableStateMap, StateMap +from synapse.types import MutableStateMap, StateMap, StrCollection logger = logging.getLogger(__name__) @@ -56,7 +55,7 @@ class StateResolutionStore(Protocol): # This is usually synapse.state.StateResolutionStore, but it's replaced with a # TestStateResolutionStore 
in tests. def get_events( - self, event_ids: Collection[str], allow_rejected: bool = False + self, event_ids: StrCollection, allow_rejected: bool = False ) -> Awaitable[Dict[str, EventBase]]: ... @@ -366,7 +365,7 @@ async def _get_auth_chain_difference( union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) - auth_difference_unpersisted_part: Collection[str] = union - intersection + auth_difference_unpersisted_part: StrCollection = union - intersection else: auth_difference_unpersisted_part = () state_sets_ids = [set(state_set.values()) for state_set in state_sets] diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index e97b844dfa..16170e0436 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -764,3 +764,14 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke } return list(results.values()) + + async def get_last_seen_for_user_id(self, user_id: str) -> Optional[int]: + """Get the last seen timestamp for a user, if we have it.""" + + return await self.db_pool.simple_select_one_onecol( + table="user_ips", + keyvalues={"user_id": user_id}, + retcol="MAX(last_seen)", + allow_none=True, + desc="get_last_seen_for_user_id", + ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index fab7008a8f..09de8f55e2 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -47,7 +47,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.types import JsonDict, StrCollection +from synapse.types import JsonDict, 
StrCollection, StrSequence from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache @@ -1179,7 +1179,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) @cached(max_entries=5000, iterable=True) - async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]: + async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence: return await self.db_pool.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 7e85b73e8e..cc964604e2 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -16,7 +16,7 @@ import logging import random import re -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast import attr @@ -192,8 +192,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) @cached() - async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]: - """Deprecated: use get_userinfo_by_id instead""" + async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]: + """Returns info about the user account, if it exists.""" def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: # We could technically use simple_select_one here, but it would not perform @@ -202,16 +202,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): txn.execute( """ SELECT - name, password_hash, is_guest, admin, consent_version, consent_ts, + name, is_guest, admin, consent_version, consent_ts, consent_server_notice_sent, appservice_id, creation_ts, user_type, deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned, COALESCE(approved, TRUE) AS approved, - COALESCE(locked, 
FALSE) AS locked, last_seen_ts + COALESCE(locked, FALSE) AS locked FROM users - LEFT JOIN ( - SELECT user_id, MAX(last_seen) AS last_seen_ts - FROM user_ips GROUP BY user_id - ) ls ON users.name = ls.user_id WHERE name = ? """, (user_id,), @@ -228,51 +224,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="get_user_by_id", func=get_user_by_id_txn, ) - - if row is not None: - # If we're using SQLite our boolean values will be integers. Because we - # present some of this data as is to e.g. server admins via REST APIs, we - # want to make sure we're returning the right type of data. - # Note: when adding a column name to this list, be wary of NULLable columns, - # since NULL values will be turned into False. - boolean_columns = [ - "admin", - "deactivated", - "shadow_banned", - "approved", - "locked", - ] - for column in boolean_columns: - row[column] = bool(row[column]) - - return row - - async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: - """Get a UserInfo object for a user by user ID. - - Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed, - this method should be cached. - - Args: - user_id: The user to fetch user info for. - Returns: - `UserInfo` object if user found, otherwise `None`. 
- """ - user_data = await self.get_user_by_id(user_id) - if not user_data: + if row is None: return None + return UserInfo( - appservice_id=user_data["appservice_id"], - consent_server_notice_sent=user_data["consent_server_notice_sent"], - consent_version=user_data["consent_version"], - creation_ts=user_data["creation_ts"], - is_admin=bool(user_data["admin"]), - is_deactivated=bool(user_data["deactivated"]), - is_guest=bool(user_data["is_guest"]), - is_shadow_banned=bool(user_data["shadow_banned"]), - user_id=UserID.from_string(user_data["name"]), - user_type=user_data["user_type"], - last_seen_ts=user_data["last_seen_ts"], + appservice_id=row["appservice_id"], + consent_server_notice_sent=row["consent_server_notice_sent"], + consent_version=row["consent_version"], + consent_ts=row["consent_ts"], + creation_ts=row["creation_ts"], + is_admin=bool(row["admin"]), + is_deactivated=bool(row["deactivated"]), + is_guest=bool(row["is_guest"]), + is_shadow_banned=bool(row["shadow_banned"]), + user_id=UserID.from_string(row["name"]), + user_type=row["user_type"], + approved=bool(row["approved"]), + locked=bool(row["locked"]), ) async def is_trial_user(self, user_id: str) -> bool: @@ -290,10 +258,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): now = self._clock.time_msec() days = self.config.server.mau_appservice_trial_days.get( - info["appservice_id"], self.config.server.mau_trial_days + info.appservice_id, self.config.server.mau_trial_days ) trial_duration_ms = days * 24 * 60 * 60 * 1000 - is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms + is_trial = (now - info.creation_ts * 1000) < trial_duration_ms return is_trial @cached() @@ -2312,6 +2280,26 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): return next_id + async def set_device_for_refresh_token( + self, user_id: str, old_device_id: str, device_id: str + ) -> None: + """Moves refresh tokens from old device to current device + + Args: + user_id: The user of 
the devices. + old_device_id: The old device. + device_id: The new device ID. + Returns: + None + """ + + await self.db_pool.simple_update( + "refresh_tokens", + keyvalues={"user_id": user_id, "device_id": old_device_id}, + updatevalues={"device_id": device_id}, + desc="set_device_for_refresh_token", + ) + def _set_device_for_access_token_txn( self, txn: LoggingTransaction, token: str, device_id: str ) -> str: diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py index 9ab120eea9..5c5372a825 100644 --- a/synapse/storage/databases/main/task_scheduler.py +++ b/synapse/storage/databases/main/task_scheduler.py @@ -53,6 +53,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore): resource_id: Optional[str] = None, statuses: Optional[List[TaskStatus]] = None, max_timestamp: Optional[int] = None, + limit: Optional[int] = None, ) -> List[ScheduledTask]: """Get a list of scheduled tasks from the DB. @@ -62,6 +63,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore): statuses: Limit the returned tasks to the specific statuses max_timestamp: Limit the returned tasks to the ones that have a timestamp inferior to the specified one + limit: Only return `limit` number of rows if set. Returns: a list of `ScheduledTask`, ordered by increasing timestamps """ @@ -94,6 +96,10 @@ class TaskSchedulerWorkerStore(SQLBaseStore): sql = sql + " ORDER BY timestamp" + if limit is not None: + sql += " LIMIT ?" 
+ args.append(limit) + txn.execute(sql, args) return self.db_pool.cursor_to_dict(txn) diff --git a/synapse/storage/schema/main/delta/82/02_scheduled_tasks_index.sql b/synapse/storage/schema/main/delta/82/02_scheduled_tasks_index.sql new file mode 100644 index 0000000000..6b90275139 --- /dev/null +++ b/synapse/storage/schema/main/delta/82/02_scheduled_tasks_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX IF NOT EXISTS scheduled_tasks_timestamp ON scheduled_tasks(timestamp); diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 488714f60c..76b0e3e694 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -933,33 +933,37 @@ def get_verify_key_from_cross_signing_key( @attr.s(auto_attribs=True, frozen=True, slots=True) class UserInfo: - """Holds information about a user. Result of get_userinfo_by_id. + """Holds information about a user. Result of get_user_by_id. Attributes: user_id: ID of the user. appservice_id: Application service ID that created this user. consent_server_notice_sent: Version of policy documents the user has been sent. consent_version: Version of policy documents the user has consented to. + consent_ts: Time the user consented creation_ts: Creation timestamp of the user. is_admin: True if the user is an admin. is_deactivated: True if the user has been deactivated. is_guest: True if the user is a guest user. 
is_shadow_banned: True if the user has been shadow-banned. user_type: User type (None for normal user, 'support' and 'bot' other options). - last_seen_ts: Last activity timestamp of the user. + approved: If the user has been "approved" to register on the server. + locked: Whether the user's account has been locked """ user_id: UserID appservice_id: Optional[int] consent_server_notice_sent: Optional[str] consent_version: Optional[str] + consent_ts: Optional[int] user_type: Optional[str] creation_ts: int is_admin: bool is_deactivated: bool is_guest: bool is_shadow_banned: bool - last_seen_ts: Optional[int] + approved: bool + locked: bool class UserProfile(TypedDict): diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py index b7de201bde..caf13b3474 100644 --- a/synapse/util/task_scheduler.py +++ b/synapse/util/task_scheduler.py @@ -15,12 +15,14 @@ import logging from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set, Tuple -from prometheus_client import Gauge - from twisted.python.failure import Failure from synapse.logging.context import nested_logging_context -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.types import JsonMapping, ScheduledTask, TaskStatus from synapse.util.stringutils import random_string @@ -30,12 +32,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -running_tasks_gauge = Gauge( - "synapse_scheduler_running_tasks", - "The number of concurrent running tasks handled by the TaskScheduler", -) - - class TaskScheduler: """ This is a simple task sheduler aimed at resumable tasks: usually we use `run_in_background` @@ -70,6 +66,8 @@ class TaskScheduler: # Precision of the scheduler, evaluation of tasks to run will only happen # every `SCHEDULE_INTERVAL_MS` ms SCHEDULE_INTERVAL_MS = 1 * 60 * 
1000 # 1mn + # How often to clean up old tasks. + CLEANUP_INTERVAL_MS = 30 * 60 * 1000 # Time before a complete or failed task is deleted from the DB KEEP_TASKS_FOR_MS = 7 * 24 * 60 * 60 * 1000 # 1 week # Maximum number of tasks that can run at the same time @@ -92,13 +90,25 @@ class TaskScheduler: ] = {} self._run_background_tasks = hs.config.worker.run_background_tasks + # Flag to make sure we only try and launch new tasks once at a time. + self._launching_new_tasks = False + if self._run_background_tasks: self._clock.looping_call( - run_as_background_process, + self._launch_scheduled_tasks, TaskScheduler.SCHEDULE_INTERVAL_MS, - "handle_scheduled_tasks", - self._handle_scheduled_tasks, ) + self._clock.looping_call( + self._clean_scheduled_tasks, + TaskScheduler.CLEANUP_INTERVAL_MS, + ) + + LaterGauge( + "synapse_scheduler_running_tasks", + "The number of concurrent running tasks handled by the TaskScheduler", + labels=None, + caller=lambda: len(self._running_tasks), + ) def register_action( self, @@ -234,6 +244,7 @@ class TaskScheduler: resource_id: Optional[str] = None, statuses: Optional[List[TaskStatus]] = None, max_timestamp: Optional[int] = None, + limit: Optional[int] = None, ) -> List[ScheduledTask]: """Get a list of tasks. Returns all the tasks if no args is provided. @@ -247,6 +258,7 @@ class TaskScheduler: statuses: Limit the returned tasks to the specific statuses max_timestamp: Limit the returned tasks to the ones that have a timestamp inferior to the specified one + limit: Only return `limit` number of rows if set. 
Returns A list of `ScheduledTask`, ordered by increasing timestamps @@ -256,6 +268,7 @@ class TaskScheduler: resource_id=resource_id, statuses=statuses, max_timestamp=max_timestamp, + limit=limit, ) async def delete_task(self, id: str) -> None: @@ -273,34 +286,58 @@ class TaskScheduler: raise Exception(f"Task {id} is currently ACTIVE and can't be deleted") await self._store.delete_scheduled_task(id) - async def _handle_scheduled_tasks(self) -> None: - """Main loop taking care of launching tasks and cleaning up old ones.""" - await self._launch_scheduled_tasks() - await self._clean_scheduled_tasks() + def launch_task_by_id(self, id: str) -> None: + """Try launching the task with the given ID.""" + # Don't bother trying to launch new tasks if we're already at capacity. + if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: + return + run_as_background_process("launch_task_by_id", self._launch_task_by_id, id) + + async def _launch_task_by_id(self, id: str) -> None: + """Helper async function for `launch_task_by_id`.""" + task = await self.get_task(id) + if task: + await self._launch_task(task) + + @wrap_as_background_process("launch_scheduled_tasks") async def _launch_scheduled_tasks(self) -> None: """Retrieve and launch scheduled tasks that should be running at that time.""" - for task in await self.get_tasks(statuses=[TaskStatus.ACTIVE]): - await self._launch_task(task) - for task in await self.get_tasks( - statuses=[TaskStatus.SCHEDULED], max_timestamp=self._clock.time_msec() - ): - await self._launch_task(task) + # Don't bother trying to launch new tasks if we're already at capacity. 
+ if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: + return - running_tasks_gauge.set(len(self._running_tasks)) + if self._launching_new_tasks: + return + self._launching_new_tasks = True + + try: + for task in await self.get_tasks( + statuses=[TaskStatus.ACTIVE], limit=self.MAX_CONCURRENT_RUNNING_TASKS + ): + await self._launch_task(task) + for task in await self.get_tasks( + statuses=[TaskStatus.SCHEDULED], + max_timestamp=self._clock.time_msec(), + limit=self.MAX_CONCURRENT_RUNNING_TASKS, + ): + await self._launch_task(task) + + finally: + self._launching_new_tasks = False + + @wrap_as_background_process("clean_scheduled_tasks") async def _clean_scheduled_tasks(self) -> None: """Clean old complete or failed jobs to avoid clutter the DB.""" + now = self._clock.time_msec() for task in await self._store.get_scheduled_tasks( - statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE] + statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE], + max_timestamp=now - TaskScheduler.KEEP_TASKS_FOR_MS, ): # FAILED and COMPLETE tasks should never be running assert task.id not in self._running_tasks - if ( - self._clock.time_msec() - > task.timestamp + TaskScheduler.KEEP_TASKS_FOR_MS - ): - await self._store.delete_scheduled_task(task.id) + await self._store.delete_scheduled_task(task.id) async def _launch_task(self, task: ScheduledTask) -> None: """Launch a scheduled task now. @@ -339,6 +376,9 @@ class TaskScheduler: ) self._running_tasks.remove(task.id) + # Try launch a new task since we've finished with this one. 
+ self._clock.call_later(1, self._launch_scheduled_tasks) + if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: return @@ -355,4 +395,4 @@ class TaskScheduler: self._running_tasks.add(task.id) await self.update_task(task.id, status=TaskStatus.ACTIVE) - run_as_background_process(task.action, wrapper) + run_as_background_process(f"task-{task.action}", wrapper) diff --git a/synapse/visibility.py b/synapse/visibility.py index eac10f6438..f15fdd8314 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -36,7 +36,7 @@ from synapse.events.utils import prune_event from synapse.logging.opentracing import trace from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore -from synapse.types import RetentionPolicy, StateMap, get_domain_from_id +from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id from synapse.types.state import StateFilter from synapse.util import Clock @@ -150,12 +150,12 @@ async def filter_events_for_client( async def filter_event_for_clients_with_state( store: DataStore, - user_ids: Collection[str], + user_ids: StrCollection, event: EventBase, context: EventContext, is_peeking: bool = False, filter_send_to_client: bool = True, -) -> Collection[str]: +) -> StrCollection: """ Checks to see if an event is visible to the users in the list at the time of the event. diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index dcd01d5688..e00d7215df 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -188,8 +188,11 @@ class AuthTestCase(unittest.HomeserverTestCase): ) app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) - # This just needs to return a truth-y value. 
- self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) + + class FakeUserInfo: + is_guest = False + + self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo()) self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) @@ -341,7 +344,10 @@ class AuthTestCase(unittest.HomeserverTestCase): ) def test_get_guest_user_from_macaroon(self) -> None: - self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True}) + class FakeUserInfo: + is_guest = True + + self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo()) self.store.get_user_by_access_token = AsyncMock(return_value=None) user_id = "@baldrick:matrix.org" diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 79d327499b..d4ed068357 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -461,6 +461,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): self.message_handler = hs.get_device_message_handler() self.registration = hs.get_registration_handler() self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() self.store = hs.get_datastores().main return hs @@ -487,11 +488,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase): self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) # Create a new login for the user and dehydrated the device - device_id, access_token, _expiration_time, _refresh_token = self.get_success( + device_id, access_token, _expiration_time, refresh_token = self.get_success( self.registration.register_device( user_id=user_id, device_id=None, initial_display_name="new device", + should_issue_refresh_token=True, ) ) @@ -522,6 +524,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase): self.assertEqual(user_info.device_id, retrieved_device_id) + # make sure the user device has the refresh token + assert refresh_token is not None + self.get_success( + self.auth_handler.refresh_token(refresh_token, 5 * 60 * 1000, 5 * 60 * 1000) + ) + 
# make sure the device has the display name that was set from the login res = self.get_success(self.handler.get_device(user_id, retrieved_device_id)) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 95c9792d54..0cca34d355 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -16,7 +16,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import UserTypes from synapse.api.errors import ThreepidValidationError from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, UserID, UserInfo from synapse.util import Clock from tests.unittest import HomeserverTestCase, override_config @@ -35,24 +35,22 @@ class RegistrationStoreTestCase(HomeserverTestCase): self.get_success(self.store.register_user(self.user_id, self.pwhash)) self.assertEqual( - { + UserInfo( # TODO(paul): Surely this field should be 'user_id', not 'name' - "name": self.user_id, - "password_hash": self.pwhash, - "admin": 0, - "is_guest": 0, - "consent_version": None, - "consent_ts": None, - "consent_server_notice_sent": None, - "appservice_id": None, - "creation_ts": 0, - "user_type": None, - "deactivated": 0, - "locked": 0, - "shadow_banned": 0, - "approved": 1, - "last_seen_ts": None, - }, + user_id=UserID.from_string(self.user_id), + is_admin=False, + is_guest=False, + consent_server_notice_sent=None, + consent_ts=None, + consent_version=None, + appservice_id=None, + creation_ts=0, + user_type=None, + is_deactivated=False, + locked=False, + is_shadow_banned=False, + approved=True, + ), (self.get_success(self.store.get_user_by_id(self.user_id))), ) @@ -65,9 +63,11 @@ class RegistrationStoreTestCase(HomeserverTestCase): user = self.get_success(self.store.get_user_by_id(self.user_id)) assert user - self.assertEqual(user["consent_version"], "1") - self.assertGreater(user["consent_ts"], before_consent) - self.assertLess(user["consent_ts"], 
self.clock.time_msec()) + self.assertEqual(user.consent_version, "1") + self.assertIsNotNone(user.consent_ts) + assert user.consent_ts is not None + self.assertGreater(user.consent_ts, before_consent) + self.assertLess(user.consent_ts, self.clock.time_msec()) def test_add_tokens(self) -> None: self.get_success(self.store.register_user(self.user_id, self.pwhash)) @@ -215,7 +215,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase): user = self.get_success(self.store.get_user_by_id(self.user_id)) assert user is not None - self.assertTrue(user["approved"]) + self.assertTrue(user.approved) approved = self.get_success(self.store.is_user_approved(self.user_id)) self.assertTrue(approved) @@ -228,7 +228,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase): user = self.get_success(self.store.get_user_by_id(self.user_id)) assert user is not None - self.assertFalse(user["approved"]) + self.assertFalse(user.approved) approved = self.get_success(self.store.is_user_approved(self.user_id)) self.assertFalse(approved) @@ -248,7 +248,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase): user = self.get_success(self.store.get_user_by_id(self.user_id)) self.assertIsNotNone(user) assert user is not None - self.assertEqual(user["approved"], 1) + self.assertTrue(user.approved) approved = self.get_success(self.store.is_user_approved(self.user_id)) self.assertTrue(approved)