mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-20 19:10:45 +03:00
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
This commit is contained in:
commit
1e0b96f1a4
43 changed files with 341 additions and 242 deletions
1
changelog.d/16288.bugfix
Normal file
1
changelog.d/16288.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix bug introduced in Synapse 1.49.0 when using dehydrated devices ([MSC2697](https://github.com/matrix-org/matrix-spec-proposals/pull/2697)) and refresh tokens. Contributed by Hanadi.
|
1
changelog.d/16301.misc
Normal file
1
changelog.d/16301.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Improve type hints.
|
1
changelog.d/16304.doc
Normal file
1
changelog.d/16304.doc
Normal file
|
@ -0,0 +1 @@
|
|||
Link to the Alpine Linux community package for Synapse.
|
1
changelog.d/16313.misc
Normal file
1
changelog.d/16313.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Delete device messages asynchronously and in staged batches using the task scheduler.
|
1
changelog.d/16314.misc
Normal file
1
changelog.d/16314.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Remove a reference cycle for in background processes.
|
1
changelog.d/16316.misc
Normal file
1
changelog.d/16316.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Refactor `get_user_by_id`.
|
1
changelog.d/16318.misc
Normal file
1
changelog.d/16318.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Speed up task to delete to-device messages.
|
|
@ -155,6 +155,14 @@ sudo pip uninstall py-bcrypt
|
|||
sudo pip install py-bcrypt
|
||||
```
|
||||
|
||||
#### Alpine Linux
|
||||
|
||||
6543 maintains [Synapse packages for Alpine Linux](https://pkgs.alpinelinux.org/packages?name=synapse&branch=edge) in the community repository. Install with:
|
||||
|
||||
```sh
|
||||
sudo apk add synapse
|
||||
```
|
||||
|
||||
#### Void Linux
|
||||
|
||||
Synapse can be found in the void repositories as
|
||||
|
|
|
@ -268,7 +268,7 @@ class InternalAuth(BaseAuth):
|
|||
stored_user = await self.store.get_user_by_id(user_id)
|
||||
if not stored_user:
|
||||
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
|
||||
if not stored_user["is_guest"]:
|
||||
if not stored_user.is_guest:
|
||||
raise InvalidClientTokenError(
|
||||
"Guest access token used for regular user"
|
||||
)
|
||||
|
|
|
@ -300,7 +300,7 @@ class MSC3861DelegatedAuth(BaseAuth):
|
|||
user_id = UserID(username, self._hostname)
|
||||
|
||||
# First try to find a user from the username claim
|
||||
user_info = await self.store.get_userinfo_by_id(user_id=user_id.to_string())
|
||||
user_info = await self.store.get_user_by_id(user_id=user_id.to_string())
|
||||
if user_info is None:
|
||||
# If the user does not exist, we should create it on the fly
|
||||
# TODO: we could use SCIM to provision users ahead of time and listen
|
||||
|
|
|
@ -27,9 +27,7 @@ from typing import (
|
|||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
NoReturn,
|
||||
Optional,
|
||||
|
@ -76,7 +74,7 @@ from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_
|
|||
from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
|
||||
load_legacy_third_party_event_rules,
|
||||
)
|
||||
from synapse.types import ISynapseReactor
|
||||
from synapse.types import ISynapseReactor, StrCollection
|
||||
from synapse.util import SYNAPSE_VERSION
|
||||
from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
|
||||
from synapse.util.daemonize import daemonize_process
|
||||
|
@ -278,7 +276,7 @@ def register_start(
|
|||
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
|
||||
|
||||
|
||||
def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
|
||||
def listen_metrics(bind_addresses: StrCollection, port: int) -> None:
|
||||
"""
|
||||
Start Prometheus metrics server.
|
||||
"""
|
||||
|
@ -315,7 +313,7 @@ def _set_prometheus_client_use_created_metrics(new_value: bool) -> None:
|
|||
|
||||
|
||||
def listen_manhole(
|
||||
bind_addresses: Collection[str],
|
||||
bind_addresses: StrCollection,
|
||||
port: int,
|
||||
manhole_settings: ManholeConfig,
|
||||
manhole_globals: dict,
|
||||
|
@ -339,7 +337,7 @@ def listen_manhole(
|
|||
|
||||
|
||||
def listen_tcp(
|
||||
bind_addresses: Collection[str],
|
||||
bind_addresses: StrCollection,
|
||||
port: int,
|
||||
factory: ServerFactory,
|
||||
reactor: IReactorTCP = reactor,
|
||||
|
@ -448,7 +446,7 @@ def listen_http(
|
|||
|
||||
|
||||
def listen_ssl(
|
||||
bind_addresses: Collection[str],
|
||||
bind_addresses: StrCollection,
|
||||
port: int,
|
||||
factory: ServerFactory,
|
||||
context_factory: IOpenSSLContextFactory,
|
||||
|
|
|
@ -26,7 +26,6 @@ from textwrap import dedent
|
|||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
|
@ -384,7 +383,7 @@ class RootConfig:
|
|||
|
||||
config_classes: List[Type[Config]] = []
|
||||
|
||||
def __init__(self, config_files: Collection[str] = ()):
|
||||
def __init__(self, config_files: StrSequence = ()):
|
||||
# Capture absolute paths here, so we can reload config after we daemonize.
|
||||
self.config_files = [os.path.abspath(path) for path in config_files]
|
||||
|
||||
|
|
|
@ -25,7 +25,6 @@ from typing import (
|
|||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
|
@ -408,7 +407,7 @@ class EventBase(metaclass=abc.ABCMeta):
|
|||
def keys(self) -> Iterable[str]:
|
||||
return self._dict.keys()
|
||||
|
||||
def prev_event_ids(self) -> Sequence[str]:
|
||||
def prev_event_ids(self) -> List[str]:
|
||||
"""Returns the list of prev event IDs. The order matches the order
|
||||
specified in the event, though there is no meaning to it.
|
||||
|
||||
|
@ -553,7 +552,7 @@ class FrozenEventV2(EventBase):
|
|||
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
|
||||
return self._event_id
|
||||
|
||||
def prev_event_ids(self) -> Sequence[str]:
|
||||
def prev_event_ids(self) -> List[str]:
|
||||
"""Returns the list of prev event IDs. The order matches the order
|
||||
specified in the event, though there is no meaning to it.
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import attr
|
||||
from signedjson.types import SigningKey
|
||||
|
@ -28,7 +28,7 @@ from synapse.event_auth import auth_types_for_event
|
|||
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
|
||||
from synapse.state import StateHandler
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.types import EventID, JsonDict
|
||||
from synapse.types import EventID, JsonDict, StrCollection
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util import Clock
|
||||
from synapse.util.stringutils import random_string
|
||||
|
@ -103,7 +103,7 @@ class EventBuilder:
|
|||
|
||||
async def build(
|
||||
self,
|
||||
prev_event_ids: Collection[str],
|
||||
prev_event_ids: StrCollection,
|
||||
auth_event_ids: Optional[List[str]],
|
||||
depth: Optional[int] = None,
|
||||
) -> EventBase:
|
||||
|
@ -136,7 +136,7 @@ class EventBuilder:
|
|||
|
||||
format_version = self.room_version.event_format
|
||||
# The types of auth/prev events changes between event versions.
|
||||
prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
|
||||
prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]]
|
||||
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
|
||||
if format_version == EventFormatVersions.ROOM_V1_V2:
|
||||
auth_events = await self._store.add_event_hashes(auth_event_ids)
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections.abc
|
||||
from typing import Iterable, List, Type, Union, cast
|
||||
from typing import List, Type, Union, cast
|
||||
|
||||
import jsonschema
|
||||
from pydantic import Field, StrictBool, StrictStr
|
||||
|
@ -36,7 +36,7 @@ from synapse.events.utils import (
|
|||
from synapse.federation.federation_server import server_matches_acl_event
|
||||
from synapse.http.servlet import validate_json_object
|
||||
from synapse.rest.models import RequestBodyModel
|
||||
from synapse.types import EventID, JsonDict, RoomID, UserID
|
||||
from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID
|
||||
|
||||
|
||||
class EventValidator:
|
||||
|
@ -225,7 +225,7 @@ class EventValidator:
|
|||
|
||||
self._ensure_state_event(event)
|
||||
|
||||
def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
|
||||
def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None:
|
||||
for s in keys:
|
||||
if s not in d:
|
||||
raise SynapseError(400, "'%s' not in content" % (s,))
|
||||
|
|
|
@ -102,7 +102,7 @@ class AccountHandler:
|
|||
"""
|
||||
status = {"exists": False}
|
||||
|
||||
userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string())
|
||||
userinfo = await self._main_store.get_user_by_id(user_id.to_string())
|
||||
|
||||
if userinfo is not None:
|
||||
status = {
|
||||
|
|
|
@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set
|
|||
|
||||
from synapse.api.constants import Direction, Membership
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
|
||||
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -57,38 +57,30 @@ class AdminHandler:
|
|||
|
||||
async def get_user(self, user: UserID) -> Optional[JsonDict]:
|
||||
"""Function to get user details"""
|
||||
user_info_dict = await self._store.get_user_by_id(user.to_string())
|
||||
if user_info_dict is None:
|
||||
user_info: Optional[UserInfo] = await self._store.get_user_by_id(
|
||||
user.to_string()
|
||||
)
|
||||
if user_info is None:
|
||||
return None
|
||||
|
||||
# Restrict returned information to a known set of fields. This prevents additional
|
||||
# fields added to get_user_by_id from modifying Synapse's external API surface.
|
||||
user_info_to_return = {
|
||||
"name",
|
||||
"admin",
|
||||
"deactivated",
|
||||
"locked",
|
||||
"shadow_banned",
|
||||
"creation_ts",
|
||||
"appservice_id",
|
||||
"consent_server_notice_sent",
|
||||
"consent_version",
|
||||
"consent_ts",
|
||||
"user_type",
|
||||
"is_guest",
|
||||
"last_seen_ts",
|
||||
user_info_dict = {
|
||||
"name": user.to_string(),
|
||||
"admin": user_info.is_admin,
|
||||
"deactivated": user_info.is_deactivated,
|
||||
"locked": user_info.locked,
|
||||
"shadow_banned": user_info.is_shadow_banned,
|
||||
"creation_ts": user_info.creation_ts,
|
||||
"appservice_id": user_info.appservice_id,
|
||||
"consent_server_notice_sent": user_info.consent_server_notice_sent,
|
||||
"consent_version": user_info.consent_version,
|
||||
"consent_ts": user_info.consent_ts,
|
||||
"user_type": user_info.user_type,
|
||||
"is_guest": user_info.is_guest,
|
||||
}
|
||||
|
||||
if self._msc3866_enabled:
|
||||
# Only include the approved flag if support for MSC3866 is enabled.
|
||||
user_info_to_return.add("approved")
|
||||
|
||||
# Restrict returned keys to a known set.
|
||||
user_info_dict = {
|
||||
key: value
|
||||
for key, value in user_info_dict.items()
|
||||
if key in user_info_to_return
|
||||
}
|
||||
user_info_dict["approved"] = user_info.approved
|
||||
|
||||
# Add additional user metadata
|
||||
profile = await self._store.get_profileinfo(user)
|
||||
|
@ -105,6 +97,9 @@ class AdminHandler:
|
|||
user_info_dict["external_ids"] = external_ids
|
||||
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
|
||||
|
||||
last_seen_ts = await self._store.get_last_seen_for_user_id(user.to_string())
|
||||
user_info_dict["last_seen_ts"] = last_seen_ts
|
||||
|
||||
return user_info_dict
|
||||
|
||||
async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
|
||||
|
|
|
@ -388,7 +388,8 @@ class DeviceWorkerHandler:
|
|||
"Trying handling device list state for partial join: not supported on workers."
|
||||
)
|
||||
|
||||
DEVICE_MSGS_DELETE_BATCH_LIMIT = 100
|
||||
DEVICE_MSGS_DELETE_BATCH_LIMIT = 1000
|
||||
DEVICE_MSGS_DELETE_SLEEP_MS = 1000
|
||||
|
||||
async def _delete_device_messages(
|
||||
self,
|
||||
|
@ -400,19 +401,19 @@ class DeviceWorkerHandler:
|
|||
device_id = task.params["device_id"]
|
||||
up_to_stream_id = task.params["up_to_stream_id"]
|
||||
|
||||
res = await self.store.delete_messages_for_device(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
up_to_stream_id=up_to_stream_id,
|
||||
limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT,
|
||||
)
|
||||
# Delete the messages in batches to avoid too much DB load.
|
||||
while True:
|
||||
res = await self.store.delete_messages_for_device(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
up_to_stream_id=up_to_stream_id,
|
||||
limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT,
|
||||
)
|
||||
|
||||
if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT:
|
||||
return TaskStatus.COMPLETE, None, None
|
||||
else:
|
||||
# There is probably still device messages to be deleted, let's keep the task active and it will be run
|
||||
# again in a subsequent scheduler loop run (probably the next one, if not too many tasks are running).
|
||||
return TaskStatus.ACTIVE, None, None
|
||||
if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT:
|
||||
return TaskStatus.COMPLETE, None, None
|
||||
|
||||
await self.clock.sleep(DeviceHandler.DEVICE_MSGS_DELETE_SLEEP_MS / 1000.0)
|
||||
|
||||
|
||||
class DeviceHandler(DeviceWorkerHandler):
|
||||
|
@ -758,12 +759,13 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||
|
||||
# If the dehydrated device was successfully deleted (the device ID
|
||||
# matched the stored dehydrated device), then modify the access
|
||||
# token to use the dehydrated device's ID and copy the old device
|
||||
# display name to the dehydrated device, and destroy the old device
|
||||
# ID
|
||||
# token and refresh token to use the dehydrated device's ID and
|
||||
# copy the old device display name to the dehydrated device,
|
||||
# and destroy the old device ID
|
||||
old_device_id = await self.store.set_device_for_access_token(
|
||||
access_token, device_id
|
||||
)
|
||||
await self.store.set_device_for_refresh_token(user_id, old_device_id, device_id)
|
||||
old_device = await self.store.get_device(user_id, old_device_id)
|
||||
if old_device is None:
|
||||
raise errors.NotFoundError()
|
||||
|
|
|
@ -828,13 +828,13 @@ class EventCreationHandler:
|
|||
|
||||
u = await self.store.get_user_by_id(user_id)
|
||||
assert u is not None
|
||||
if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
|
||||
if u.user_type in (UserTypes.SUPPORT, UserTypes.BOT):
|
||||
# support and bot users are not required to consent
|
||||
return
|
||||
if u["appservice_id"] is not None:
|
||||
if u.appservice_id is not None:
|
||||
# users registered by an appservice are exempt
|
||||
return
|
||||
if u["consent_version"] == self.config.consent.user_consent_version:
|
||||
if u.consent_version == self.config.consent.user_consent_version:
|
||||
return
|
||||
|
||||
consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
|
||||
|
|
|
@ -78,7 +78,7 @@ from synapse.http.replicationagent import ReplicationAgent
|
|||
from synapse.http.types import QueryParams
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
||||
from synapse.types import ISynapseReactor
|
||||
from synapse.types import ISynapseReactor, StrSequence
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.async_helpers import timeout_deferred
|
||||
|
||||
|
@ -108,10 +108,9 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValu
|
|||
# the value actually has to be a List, but List is invariant so we can't specify that
|
||||
# the entries can either be Lists or bytes.
|
||||
RawHeaderValue = Union[
|
||||
List[str],
|
||||
StrSequence,
|
||||
List[bytes],
|
||||
List[Union[str, bytes]],
|
||||
Tuple[str, ...],
|
||||
Tuple[bytes, ...],
|
||||
Tuple[Union[str, bytes], ...],
|
||||
]
|
||||
|
|
|
@ -18,7 +18,6 @@ import logging
|
|||
from http import HTTPStatus
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
|
@ -38,7 +37,7 @@ from twisted.web.server import Request
|
|||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http import redact_uri
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.types import JsonDict, RoomAlias, RoomID
|
||||
from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection
|
||||
from synapse.util import json_decoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -340,7 +339,7 @@ def parse_string(
|
|||
name: str,
|
||||
default: str,
|
||||
*,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> str:
|
||||
...
|
||||
|
@ -352,7 +351,7 @@ def parse_string(
|
|||
name: str,
|
||||
*,
|
||||
required: Literal[True],
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> str:
|
||||
...
|
||||
|
@ -365,7 +364,7 @@ def parse_string(
|
|||
*,
|
||||
default: Optional[str] = None,
|
||||
required: bool = False,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> Optional[str]:
|
||||
...
|
||||
|
@ -376,7 +375,7 @@ def parse_string(
|
|||
name: str,
|
||||
default: Optional[str] = None,
|
||||
required: bool = False,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
|
@ -485,7 +484,7 @@ def parse_enum(
|
|||
|
||||
def _parse_string_value(
|
||||
value: bytes,
|
||||
allowed_values: Optional[Iterable[str]],
|
||||
allowed_values: Optional[StrCollection],
|
||||
name: str,
|
||||
encoding: str,
|
||||
) -> str:
|
||||
|
@ -511,7 +510,7 @@ def parse_strings_from_args(
|
|||
args: Mapping[bytes, Sequence[bytes]],
|
||||
name: str,
|
||||
*,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> Optional[List[str]]:
|
||||
...
|
||||
|
@ -523,7 +522,7 @@ def parse_strings_from_args(
|
|||
name: str,
|
||||
default: List[str],
|
||||
*,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> List[str]:
|
||||
...
|
||||
|
@ -535,7 +534,7 @@ def parse_strings_from_args(
|
|||
name: str,
|
||||
*,
|
||||
required: Literal[True],
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> List[str]:
|
||||
...
|
||||
|
@ -548,7 +547,7 @@ def parse_strings_from_args(
|
|||
default: Optional[List[str]] = None,
|
||||
*,
|
||||
required: bool = False,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> Optional[List[str]]:
|
||||
...
|
||||
|
@ -559,7 +558,7 @@ def parse_strings_from_args(
|
|||
name: str,
|
||||
default: Optional[List[str]] = None,
|
||||
required: bool = False,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> Optional[List[str]]:
|
||||
"""
|
||||
|
@ -610,7 +609,7 @@ def parse_string_from_args(
|
|||
name: str,
|
||||
default: Optional[str] = None,
|
||||
*,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> Optional[str]:
|
||||
...
|
||||
|
@ -623,7 +622,7 @@ def parse_string_from_args(
|
|||
default: Optional[str] = None,
|
||||
*,
|
||||
required: Literal[True],
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> str:
|
||||
...
|
||||
|
@ -635,7 +634,7 @@ def parse_string_from_args(
|
|||
name: str,
|
||||
default: Optional[str] = None,
|
||||
required: bool = False,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> Optional[str]:
|
||||
...
|
||||
|
@ -646,7 +645,7 @@ def parse_string_from_args(
|
|||
name: str,
|
||||
default: Optional[str] = None,
|
||||
required: bool = False,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
allowed_values: Optional[StrCollection] = None,
|
||||
encoding: str = "ascii",
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
|
@ -821,7 +820,7 @@ def parse_and_validate_json_object_from_request(
|
|||
return validate_json_object(content, model_type)
|
||||
|
||||
|
||||
def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
|
||||
def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None:
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
|
|
|
@ -25,7 +25,6 @@ from typing import (
|
|||
Iterable,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
|
@ -49,6 +48,7 @@ import synapse.metrics._reactor_metrics # noqa: F401
|
|||
from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
|
||||
from synapse.metrics._twisted_exposition import MetricsResource, generate_latest
|
||||
from synapse.metrics._types import Collector
|
||||
from synapse.types import StrSequence
|
||||
from synapse.util import SYNAPSE_VERSION
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -81,7 +81,7 @@ class LaterGauge(Collector):
|
|||
|
||||
name: str
|
||||
desc: str
|
||||
labels: Optional[Sequence[str]] = attr.ib(hash=False)
|
||||
labels: Optional[StrSequence] = attr.ib(hash=False)
|
||||
# callback: should either return a value (if there are no labels for this metric),
|
||||
# or dict mapping from a label tuple to a value
|
||||
caller: Callable[
|
||||
|
@ -143,8 +143,8 @@ class InFlightGauge(Generic[MetricsEntry], Collector):
|
|||
self,
|
||||
name: str,
|
||||
desc: str,
|
||||
labels: Sequence[str],
|
||||
sub_metrics: Sequence[str],
|
||||
labels: StrSequence,
|
||||
sub_metrics: StrSequence,
|
||||
):
|
||||
self.name = name
|
||||
self.desc = desc
|
||||
|
|
|
@ -322,13 +322,21 @@ class BackgroundProcessLoggingContext(LoggingContext):
|
|||
if instance_id is None:
|
||||
instance_id = id(self)
|
||||
super().__init__("%s-%s" % (name, instance_id))
|
||||
self._proc = _BackgroundProcess(name, self)
|
||||
self._proc: Optional[_BackgroundProcess] = _BackgroundProcess(name, self)
|
||||
|
||||
def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
|
||||
"""Log context has started running (again)."""
|
||||
|
||||
super().start(rusage)
|
||||
|
||||
if self._proc is None:
|
||||
logger.error(
|
||||
"Background process re-entered without a proc: %s",
|
||||
self.name,
|
||||
stack_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
# We've become active again so we make sure we're in the list of active
|
||||
# procs. (Note that "start" here means we've become active, as opposed
|
||||
# to starting for the first time.)
|
||||
|
@ -345,6 +353,14 @@ class BackgroundProcessLoggingContext(LoggingContext):
|
|||
|
||||
super().__exit__(type, value, traceback)
|
||||
|
||||
if self._proc is None:
|
||||
logger.error(
|
||||
"Background process exited without a proc: %s",
|
||||
self.name,
|
||||
stack_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
# The background process has finished. We explicitly remove and manually
|
||||
# update the metrics here so that if nothing is scraping metrics the set
|
||||
# doesn't infinitely grow.
|
||||
|
@ -352,3 +368,6 @@ class BackgroundProcessLoggingContext(LoggingContext):
|
|||
_background_processes_active_since_last_scrape.discard(self._proc)
|
||||
|
||||
self._proc.update_metrics()
|
||||
|
||||
# Set proc to None to break the reference cycle.
|
||||
self._proc = None
|
||||
|
|
|
@ -572,7 +572,7 @@ class ModuleApi:
|
|||
Returns:
|
||||
UserInfo object if a user was found, otherwise None
|
||||
"""
|
||||
return await self._store.get_userinfo_by_id(user_id)
|
||||
return await self._store.get_user_by_id(user_id)
|
||||
|
||||
async def get_user_by_req(
|
||||
self,
|
||||
|
@ -1878,7 +1878,7 @@ class AccountDataManager:
|
|||
raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}")
|
||||
|
||||
# Ensure the user exists, so we don't just write to users that aren't there.
|
||||
if await self._store.get_userinfo_by_id(user_id) is None:
|
||||
if await self._store.get_user_by_id(user_id) is None:
|
||||
raise ValueError(f"User {user_id} does not exist on this server.")
|
||||
|
||||
await self._handler.add_account_data_for_user(user_id, data_type, new_data)
|
||||
|
|
|
@ -104,7 +104,7 @@ class _NotifierUserStream:
|
|||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
rooms: Collection[str],
|
||||
rooms: StrCollection,
|
||||
current_token: StreamToken,
|
||||
time_now_ms: int,
|
||||
):
|
||||
|
@ -457,7 +457,7 @@ class Notifier:
|
|||
stream_key: str,
|
||||
new_token: Union[int, RoomStreamToken],
|
||||
users: Optional[Collection[Union[str, UserID]]] = None,
|
||||
rooms: Optional[Collection[str]] = None,
|
||||
rooms: Optional[StrCollection] = None,
|
||||
) -> None:
|
||||
"""Used to inform listeners that something has happened event wise.
|
||||
|
||||
|
@ -529,7 +529,7 @@ class Notifier:
|
|||
user_id: str,
|
||||
timeout: int,
|
||||
callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
|
||||
room_ids: Optional[Collection[str]] = None,
|
||||
room_ids: Optional[StrCollection] = None,
|
||||
from_token: StreamToken = StreamToken.START,
|
||||
) -> T:
|
||||
"""Wait until the callback returns a non empty response or the
|
||||
|
|
|
@ -672,14 +672,12 @@ class ReplicationCommandHandler:
|
|||
cmd.instance_name, cmd.lock_name, cmd.lock_key
|
||||
)
|
||||
|
||||
async def on_NEW_ACTIVE_TASK(
|
||||
def on_NEW_ACTIVE_TASK(
|
||||
self, conn: IReplicationConnection, cmd: NewActiveTaskCommand
|
||||
) -> None:
|
||||
"""Called when get a new NEW_ACTIVE_TASK command."""
|
||||
if self._task_scheduler:
|
||||
task = await self._task_scheduler.get_task(cmd.data)
|
||||
if task:
|
||||
await self._task_scheduler._launch_task(task)
|
||||
self._task_scheduler.launch_task_by_id(cmd.data)
|
||||
|
||||
def new_connection(self, connection: IReplicationConnection) -> None:
|
||||
"""Called when we have a new connection."""
|
||||
|
|
|
@ -20,14 +20,14 @@ from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar,
|
|||
|
||||
from synapse.api.errors import InteractiveAuthIncompleteError
|
||||
from synapse.api.urls import CLIENT_API_PREFIX
|
||||
from synapse.types import JsonDict
|
||||
from synapse.types import JsonDict, StrCollection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def client_patterns(
|
||||
path_regex: str,
|
||||
releases: Iterable[str] = ("r0", "v3"),
|
||||
releases: StrCollection = ("r0", "v3"),
|
||||
unstable: bool = True,
|
||||
v1: bool = False,
|
||||
) -> Iterable[Pattern]:
|
||||
|
|
|
@ -129,7 +129,7 @@ class ConsentResource(DirectServeHtmlResource):
|
|||
if u is None:
|
||||
raise NotFoundError("Unknown user")
|
||||
|
||||
has_consented = u["consent_version"] == version
|
||||
has_consented = u.consent_version == version
|
||||
userhmac = userhmac_bytes.decode("ascii")
|
||||
|
||||
try:
|
||||
|
|
|
@ -79,15 +79,15 @@ class ConsentServerNotices:
|
|||
if u is None:
|
||||
return
|
||||
|
||||
if u["is_guest"] and not self._send_to_guests:
|
||||
if u.is_guest and not self._send_to_guests:
|
||||
# don't send to guests
|
||||
return
|
||||
|
||||
if u["consent_version"] == self._current_consent_version:
|
||||
if u.consent_version == self._current_consent_version:
|
||||
# user has already consented
|
||||
return
|
||||
|
||||
if u["consent_server_notice_sent"] == self._current_consent_version:
|
||||
if u.consent_server_notice_sent == self._current_consent_version:
|
||||
# we've already sent a notice to the user
|
||||
return
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@ from typing import (
|
|||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
|
@ -49,7 +48,7 @@ from synapse.logging.opentracing import tag_args, trace
|
|||
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
|
||||
from synapse.state import v1, v2
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.types import StateMap
|
||||
from synapse.types import StateMap, StrCollection
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
|
@ -197,7 +196,7 @@ class StateHandler:
|
|||
async def compute_state_after_events(
|
||||
self,
|
||||
room_id: str,
|
||||
event_ids: Collection[str],
|
||||
event_ids: StrCollection,
|
||||
state_filter: Optional[StateFilter] = None,
|
||||
await_full_state: bool = True,
|
||||
) -> StateMap[str]:
|
||||
|
@ -231,7 +230,7 @@ class StateHandler:
|
|||
return await ret.get_state(self._state_storage_controller, state_filter)
|
||||
|
||||
async def get_current_user_ids_in_room(
|
||||
self, room_id: str, latest_event_ids: Collection[str]
|
||||
self, room_id: str, latest_event_ids: StrCollection
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Get the users IDs who are currently in a room.
|
||||
|
@ -256,7 +255,7 @@ class StateHandler:
|
|||
return await self.store.get_joined_user_ids_from_state(room_id, state)
|
||||
|
||||
async def get_hosts_in_room_at_events(
|
||||
self, room_id: str, event_ids: Collection[str]
|
||||
self, room_id: str, event_ids: StrCollection
|
||||
) -> FrozenSet[str]:
|
||||
"""Get the hosts that were in a room at the given event ids
|
||||
|
||||
|
@ -470,7 +469,7 @@ class StateHandler:
|
|||
@trace
|
||||
@measure_func()
|
||||
async def resolve_state_groups_for_events(
|
||||
self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
|
||||
self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
|
||||
) -> _StateCacheEntry:
|
||||
"""Given a list of event_ids this method fetches the state at each
|
||||
event, resolves conflicts between them and returns them.
|
||||
|
@ -882,7 +881,7 @@ class StateResolutionStore:
|
|||
store: "DataStore"
|
||||
|
||||
def get_events(
|
||||
self, event_ids: Collection[str], allow_rejected: bool = False
|
||||
self, event_ids: StrCollection, allow_rejected: bool = False
|
||||
) -> Awaitable[Dict[str, EventBase]]:
|
||||
"""Get events from the database
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@ import logging
|
|||
from typing import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
|
@ -32,7 +31,7 @@ from synapse.api.constants import EventTypes
|
|||
from synapse.api.errors import AuthError
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import MutableStateMap, StateMap
|
||||
from synapse.types import MutableStateMap, StateMap, StrCollection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -45,7 +44,7 @@ async def resolve_events_with_store(
|
|||
room_version: RoomVersion,
|
||||
state_sets: Sequence[StateMap[str]],
|
||||
event_map: Optional[Dict[str, EventBase]],
|
||||
state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
|
||||
state_map_factory: Callable[[StrCollection], Awaitable[Dict[str, EventBase]]],
|
||||
) -> StateMap[str]:
|
||||
"""
|
||||
Args:
|
||||
|
|
|
@ -19,7 +19,6 @@ from typing import (
|
|||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
|
@ -39,7 +38,7 @@ from synapse.api.constants import EventTypes
|
|||
from synapse.api.errors import AuthError
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import MutableStateMap, StateMap
|
||||
from synapse.types import MutableStateMap, StateMap, StrCollection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -56,7 +55,7 @@ class StateResolutionStore(Protocol):
|
|||
# This is usually synapse.state.StateResolutionStore, but it's replaced with a
|
||||
# TestStateResolutionStore in tests.
|
||||
def get_events(
|
||||
self, event_ids: Collection[str], allow_rejected: bool = False
|
||||
self, event_ids: StrCollection, allow_rejected: bool = False
|
||||
) -> Awaitable[Dict[str, EventBase]]:
|
||||
...
|
||||
|
||||
|
@ -366,7 +365,7 @@ async def _get_auth_chain_difference(
|
|||
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
|
||||
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
|
||||
|
||||
auth_difference_unpersisted_part: Collection[str] = union - intersection
|
||||
auth_difference_unpersisted_part: StrCollection = union - intersection
|
||||
else:
|
||||
auth_difference_unpersisted_part = ()
|
||||
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
|
||||
|
|
|
@ -764,3 +764,14 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
|
|||
}
|
||||
|
||||
return list(results.values())
|
||||
|
||||
async def get_last_seen_for_user_id(self, user_id: str) -> Optional[int]:
|
||||
"""Get the last seen timestamp for a user, if we have it."""
|
||||
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="user_ips",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="MAX(last_seen)",
|
||||
allow_none=True,
|
||||
desc="get_last_seen_for_user_id",
|
||||
)
|
||||
|
|
|
@ -47,7 +47,7 @@ from synapse.storage.database import (
|
|||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.signatures import SignatureWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.types import JsonDict, StrCollection
|
||||
from synapse.types import JsonDict, StrCollection, StrSequence
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
|
@ -1179,7 +1179,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
)
|
||||
|
||||
@cached(max_entries=5000, iterable=True)
|
||||
async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
|
||||
async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence:
|
||||
return await self.db_pool.simple_select_onecol(
|
||||
table="event_forward_extremities",
|
||||
keyvalues={"room_id": room_id},
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import logging
|
||||
import random
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -192,8 +192,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
)
|
||||
|
||||
@cached()
|
||||
async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
|
||||
"""Deprecated: use get_userinfo_by_id instead"""
|
||||
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
|
||||
"""Returns info about the user account, if it exists."""
|
||||
|
||||
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
|
||||
# We could technically use simple_select_one here, but it would not perform
|
||||
|
@ -202,16 +202,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
txn.execute(
|
||||
"""
|
||||
SELECT
|
||||
name, password_hash, is_guest, admin, consent_version, consent_ts,
|
||||
name, is_guest, admin, consent_version, consent_ts,
|
||||
consent_server_notice_sent, appservice_id, creation_ts, user_type,
|
||||
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
|
||||
COALESCE(approved, TRUE) AS approved,
|
||||
COALESCE(locked, FALSE) AS locked, last_seen_ts
|
||||
COALESCE(locked, FALSE) AS locked
|
||||
FROM users
|
||||
LEFT JOIN (
|
||||
SELECT user_id, MAX(last_seen) AS last_seen_ts
|
||||
FROM user_ips GROUP BY user_id
|
||||
) ls ON users.name = ls.user_id
|
||||
WHERE name = ?
|
||||
""",
|
||||
(user_id,),
|
||||
|
@ -228,51 +224,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
desc="get_user_by_id",
|
||||
func=get_user_by_id_txn,
|
||||
)
|
||||
|
||||
if row is not None:
|
||||
# If we're using SQLite our boolean values will be integers. Because we
|
||||
# present some of this data as is to e.g. server admins via REST APIs, we
|
||||
# want to make sure we're returning the right type of data.
|
||||
# Note: when adding a column name to this list, be wary of NULLable columns,
|
||||
# since NULL values will be turned into False.
|
||||
boolean_columns = [
|
||||
"admin",
|
||||
"deactivated",
|
||||
"shadow_banned",
|
||||
"approved",
|
||||
"locked",
|
||||
]
|
||||
for column in boolean_columns:
|
||||
row[column] = bool(row[column])
|
||||
|
||||
return row
|
||||
|
||||
async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
|
||||
"""Get a UserInfo object for a user by user ID.
|
||||
|
||||
Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
|
||||
this method should be cached.
|
||||
|
||||
Args:
|
||||
user_id: The user to fetch user info for.
|
||||
Returns:
|
||||
`UserInfo` object if user found, otherwise `None`.
|
||||
"""
|
||||
user_data = await self.get_user_by_id(user_id)
|
||||
if not user_data:
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
return UserInfo(
|
||||
appservice_id=user_data["appservice_id"],
|
||||
consent_server_notice_sent=user_data["consent_server_notice_sent"],
|
||||
consent_version=user_data["consent_version"],
|
||||
creation_ts=user_data["creation_ts"],
|
||||
is_admin=bool(user_data["admin"]),
|
||||
is_deactivated=bool(user_data["deactivated"]),
|
||||
is_guest=bool(user_data["is_guest"]),
|
||||
is_shadow_banned=bool(user_data["shadow_banned"]),
|
||||
user_id=UserID.from_string(user_data["name"]),
|
||||
user_type=user_data["user_type"],
|
||||
last_seen_ts=user_data["last_seen_ts"],
|
||||
appservice_id=row["appservice_id"],
|
||||
consent_server_notice_sent=row["consent_server_notice_sent"],
|
||||
consent_version=row["consent_version"],
|
||||
consent_ts=row["consent_ts"],
|
||||
creation_ts=row["creation_ts"],
|
||||
is_admin=bool(row["admin"]),
|
||||
is_deactivated=bool(row["deactivated"]),
|
||||
is_guest=bool(row["is_guest"]),
|
||||
is_shadow_banned=bool(row["shadow_banned"]),
|
||||
user_id=UserID.from_string(row["name"]),
|
||||
user_type=row["user_type"],
|
||||
approved=bool(row["approved"]),
|
||||
locked=bool(row["locked"]),
|
||||
)
|
||||
|
||||
async def is_trial_user(self, user_id: str) -> bool:
|
||||
|
@ -290,10 +258,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
now = self._clock.time_msec()
|
||||
days = self.config.server.mau_appservice_trial_days.get(
|
||||
info["appservice_id"], self.config.server.mau_trial_days
|
||||
info.appservice_id, self.config.server.mau_trial_days
|
||||
)
|
||||
trial_duration_ms = days * 24 * 60 * 60 * 1000
|
||||
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
|
||||
is_trial = (now - info.creation_ts * 1000) < trial_duration_ms
|
||||
return is_trial
|
||||
|
||||
@cached()
|
||||
|
@ -2312,6 +2280,26 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
|
||||
return next_id
|
||||
|
||||
async def set_device_for_refresh_token(
|
||||
self, user_id: str, old_device_id: str, device_id: str
|
||||
) -> None:
|
||||
"""Moves refresh tokens from old device to current device
|
||||
|
||||
Args:
|
||||
user_id: The user of the devices.
|
||||
old_device_id: The old device.
|
||||
device_id: The new device ID.
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
await self.db_pool.simple_update(
|
||||
"refresh_tokens",
|
||||
keyvalues={"user_id": user_id, "device_id": old_device_id},
|
||||
updatevalues={"device_id": device_id},
|
||||
desc="set_device_for_refresh_token",
|
||||
)
|
||||
|
||||
def _set_device_for_access_token_txn(
|
||||
self, txn: LoggingTransaction, token: str, device_id: str
|
||||
) -> str:
|
||||
|
|
|
@ -53,6 +53,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
|
|||
resource_id: Optional[str] = None,
|
||||
statuses: Optional[List[TaskStatus]] = None,
|
||||
max_timestamp: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[ScheduledTask]:
|
||||
"""Get a list of scheduled tasks from the DB.
|
||||
|
||||
|
@ -62,6 +63,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
|
|||
statuses: Limit the returned tasks to the specific statuses
|
||||
max_timestamp: Limit the returned tasks to the ones that have
|
||||
a timestamp inferior to the specified one
|
||||
limit: Only return `limit` number of rows if set.
|
||||
|
||||
Returns: a list of `ScheduledTask`, ordered by increasing timestamps
|
||||
"""
|
||||
|
@ -94,6 +96,10 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
|
|||
|
||||
sql = sql + " ORDER BY timestamp"
|
||||
|
||||
if limit is not None:
|
||||
sql += " LIMIT ?"
|
||||
args.append(limit)
|
||||
|
||||
txn.execute(sql, args)
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
/* Copyright 2023 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE INDEX IF NOT EXISTS scheduled_tasks_timestamp ON scheduled_tasks(timestamp);
|
|
@ -933,33 +933,37 @@ def get_verify_key_from_cross_signing_key(
|
|||
|
||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
||||
class UserInfo:
|
||||
"""Holds information about a user. Result of get_userinfo_by_id.
|
||||
"""Holds information about a user. Result of get_user_by_id.
|
||||
|
||||
Attributes:
|
||||
user_id: ID of the user.
|
||||
appservice_id: Application service ID that created this user.
|
||||
consent_server_notice_sent: Version of policy documents the user has been sent.
|
||||
consent_version: Version of policy documents the user has consented to.
|
||||
consent_ts: Time the user consented
|
||||
creation_ts: Creation timestamp of the user.
|
||||
is_admin: True if the user is an admin.
|
||||
is_deactivated: True if the user has been deactivated.
|
||||
is_guest: True if the user is a guest user.
|
||||
is_shadow_banned: True if the user has been shadow-banned.
|
||||
user_type: User type (None for normal user, 'support' and 'bot' other options).
|
||||
last_seen_ts: Last activity timestamp of the user.
|
||||
approved: If the user has been "approved" to register on the server.
|
||||
locked: Whether the user's account has been locked
|
||||
"""
|
||||
|
||||
user_id: UserID
|
||||
appservice_id: Optional[int]
|
||||
consent_server_notice_sent: Optional[str]
|
||||
consent_version: Optional[str]
|
||||
consent_ts: Optional[int]
|
||||
user_type: Optional[str]
|
||||
creation_ts: int
|
||||
is_admin: bool
|
||||
is_deactivated: bool
|
||||
is_guest: bool
|
||||
is_shadow_banned: bool
|
||||
last_seen_ts: Optional[int]
|
||||
approved: bool
|
||||
locked: bool
|
||||
|
||||
|
||||
class UserProfile(TypedDict):
|
||||
|
|
|
@ -15,12 +15,14 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from prometheus_client import Gauge
|
||||
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synapse.logging.context import nested_logging_context
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import (
|
||||
run_as_background_process,
|
||||
wrap_as_background_process,
|
||||
)
|
||||
from synapse.types import JsonMapping, ScheduledTask, TaskStatus
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
|
@ -30,12 +32,6 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
running_tasks_gauge = Gauge(
|
||||
"synapse_scheduler_running_tasks",
|
||||
"The number of concurrent running tasks handled by the TaskScheduler",
|
||||
)
|
||||
|
||||
|
||||
class TaskScheduler:
|
||||
"""
|
||||
This is a simple task sheduler aimed at resumable tasks: usually we use `run_in_background`
|
||||
|
@ -70,6 +66,8 @@ class TaskScheduler:
|
|||
# Precision of the scheduler, evaluation of tasks to run will only happen
|
||||
# every `SCHEDULE_INTERVAL_MS` ms
|
||||
SCHEDULE_INTERVAL_MS = 1 * 60 * 1000 # 1mn
|
||||
# How often to clean up old tasks.
|
||||
CLEANUP_INTERVAL_MS = 30 * 60 * 1000
|
||||
# Time before a complete or failed task is deleted from the DB
|
||||
KEEP_TASKS_FOR_MS = 7 * 24 * 60 * 60 * 1000 # 1 week
|
||||
# Maximum number of tasks that can run at the same time
|
||||
|
@ -92,13 +90,25 @@ class TaskScheduler:
|
|||
] = {}
|
||||
self._run_background_tasks = hs.config.worker.run_background_tasks
|
||||
|
||||
# Flag to make sure we only try and launch new tasks once at a time.
|
||||
self._launching_new_tasks = False
|
||||
|
||||
if self._run_background_tasks:
|
||||
self._clock.looping_call(
|
||||
run_as_background_process,
|
||||
self._launch_scheduled_tasks,
|
||||
TaskScheduler.SCHEDULE_INTERVAL_MS,
|
||||
"handle_scheduled_tasks",
|
||||
self._handle_scheduled_tasks,
|
||||
)
|
||||
self._clock.looping_call(
|
||||
self._clean_scheduled_tasks,
|
||||
TaskScheduler.SCHEDULE_INTERVAL_MS,
|
||||
)
|
||||
|
||||
LaterGauge(
|
||||
"synapse_scheduler_running_tasks",
|
||||
"The number of concurrent running tasks handled by the TaskScheduler",
|
||||
labels=None,
|
||||
caller=lambda: len(self._running_tasks),
|
||||
)
|
||||
|
||||
def register_action(
|
||||
self,
|
||||
|
@ -234,6 +244,7 @@ class TaskScheduler:
|
|||
resource_id: Optional[str] = None,
|
||||
statuses: Optional[List[TaskStatus]] = None,
|
||||
max_timestamp: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[ScheduledTask]:
|
||||
"""Get a list of tasks. Returns all the tasks if no args is provided.
|
||||
|
||||
|
@ -247,6 +258,7 @@ class TaskScheduler:
|
|||
statuses: Limit the returned tasks to the specific statuses
|
||||
max_timestamp: Limit the returned tasks to the ones that have
|
||||
a timestamp inferior to the specified one
|
||||
limit: Only return `limit` number of rows if set.
|
||||
|
||||
Returns
|
||||
A list of `ScheduledTask`, ordered by increasing timestamps
|
||||
|
@ -256,6 +268,7 @@ class TaskScheduler:
|
|||
resource_id=resource_id,
|
||||
statuses=statuses,
|
||||
max_timestamp=max_timestamp,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
async def delete_task(self, id: str) -> None:
|
||||
|
@ -273,34 +286,58 @@ class TaskScheduler:
|
|||
raise Exception(f"Task {id} is currently ACTIVE and can't be deleted")
|
||||
await self._store.delete_scheduled_task(id)
|
||||
|
||||
async def _handle_scheduled_tasks(self) -> None:
|
||||
"""Main loop taking care of launching tasks and cleaning up old ones."""
|
||||
await self._launch_scheduled_tasks()
|
||||
await self._clean_scheduled_tasks()
|
||||
def launch_task_by_id(self, id: str) -> None:
|
||||
"""Try launching the task with the given ID."""
|
||||
# Don't bother trying to launch new tasks if we're already at capacity.
|
||||
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
|
||||
return
|
||||
|
||||
run_as_background_process("launch_task_by_id", self._launch_task_by_id, id)
|
||||
|
||||
async def _launch_task_by_id(self, id: str) -> None:
|
||||
"""Helper async function for `launch_task_by_id`."""
|
||||
task = await self.get_task(id)
|
||||
if task:
|
||||
await self._launch_task(task)
|
||||
|
||||
@wrap_as_background_process("launch_scheduled_tasks")
|
||||
async def _launch_scheduled_tasks(self) -> None:
|
||||
"""Retrieve and launch scheduled tasks that should be running at that time."""
|
||||
for task in await self.get_tasks(statuses=[TaskStatus.ACTIVE]):
|
||||
await self._launch_task(task)
|
||||
for task in await self.get_tasks(
|
||||
statuses=[TaskStatus.SCHEDULED], max_timestamp=self._clock.time_msec()
|
||||
):
|
||||
await self._launch_task(task)
|
||||
# Don't bother trying to launch new tasks if we're already at capacity.
|
||||
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
|
||||
return
|
||||
|
||||
running_tasks_gauge.set(len(self._running_tasks))
|
||||
if self._launching_new_tasks:
|
||||
return
|
||||
|
||||
self._launching_new_tasks = True
|
||||
|
||||
try:
|
||||
for task in await self.get_tasks(
|
||||
statuses=[TaskStatus.ACTIVE], limit=self.MAX_CONCURRENT_RUNNING_TASKS
|
||||
):
|
||||
await self._launch_task(task)
|
||||
for task in await self.get_tasks(
|
||||
statuses=[TaskStatus.SCHEDULED],
|
||||
max_timestamp=self._clock.time_msec(),
|
||||
limit=self.MAX_CONCURRENT_RUNNING_TASKS,
|
||||
):
|
||||
await self._launch_task(task)
|
||||
|
||||
finally:
|
||||
self._launching_new_tasks = False
|
||||
|
||||
@wrap_as_background_process("clean_scheduled_tasks")
|
||||
async def _clean_scheduled_tasks(self) -> None:
|
||||
"""Clean old complete or failed jobs to avoid clutter the DB."""
|
||||
now = self._clock.time_msec()
|
||||
for task in await self._store.get_scheduled_tasks(
|
||||
statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE]
|
||||
statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE],
|
||||
max_timestamp=now - TaskScheduler.KEEP_TASKS_FOR_MS,
|
||||
):
|
||||
# FAILED and COMPLETE tasks should never be running
|
||||
assert task.id not in self._running_tasks
|
||||
if (
|
||||
self._clock.time_msec()
|
||||
> task.timestamp + TaskScheduler.KEEP_TASKS_FOR_MS
|
||||
):
|
||||
await self._store.delete_scheduled_task(task.id)
|
||||
await self._store.delete_scheduled_task(task.id)
|
||||
|
||||
async def _launch_task(self, task: ScheduledTask) -> None:
|
||||
"""Launch a scheduled task now.
|
||||
|
@ -339,6 +376,9 @@ class TaskScheduler:
|
|||
)
|
||||
self._running_tasks.remove(task.id)
|
||||
|
||||
# Try launch a new task since we've finished with this one.
|
||||
self._clock.call_later(1, self._launch_scheduled_tasks)
|
||||
|
||||
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
|
||||
return
|
||||
|
||||
|
@ -355,4 +395,4 @@ class TaskScheduler:
|
|||
|
||||
self._running_tasks.add(task.id)
|
||||
await self.update_task(task.id, status=TaskStatus.ACTIVE)
|
||||
run_as_background_process(task.action, wrapper)
|
||||
run_as_background_process(f"task-{task.action}", wrapper)
|
||||
|
|
|
@ -36,7 +36,7 @@ from synapse.events.utils import prune_event
|
|||
from synapse.logging.opentracing import trace
|
||||
from synapse.storage.controllers import StorageControllers
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
|
||||
from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util import Clock
|
||||
|
||||
|
@ -150,12 +150,12 @@ async def filter_events_for_client(
|
|||
|
||||
async def filter_event_for_clients_with_state(
|
||||
store: DataStore,
|
||||
user_ids: Collection[str],
|
||||
user_ids: StrCollection,
|
||||
event: EventBase,
|
||||
context: EventContext,
|
||||
is_peeking: bool = False,
|
||||
filter_send_to_client: bool = True,
|
||||
) -> Collection[str]:
|
||||
) -> StrCollection:
|
||||
"""
|
||||
Checks to see if an event is visible to the users in the list at the time of
|
||||
the event.
|
||||
|
|
|
@ -188,8 +188,11 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
app_service.is_interested_in_user = Mock(return_value=True)
|
||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||
# This just needs to return a truth-y value.
|
||||
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
|
||||
|
||||
class FakeUserInfo:
|
||||
is_guest = False
|
||||
|
||||
self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
|
||||
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||
|
||||
request = Mock(args={})
|
||||
|
@ -341,7 +344,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
def test_get_guest_user_from_macaroon(self) -> None:
|
||||
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
|
||||
class FakeUserInfo:
|
||||
is_guest = True
|
||||
|
||||
self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
|
||||
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||
|
||||
user_id = "@baldrick:matrix.org"
|
||||
|
|
|
@ -461,6 +461,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
|
|||
self.message_handler = hs.get_device_message_handler()
|
||||
self.registration = hs.get_registration_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.store = hs.get_datastores().main
|
||||
return hs
|
||||
|
||||
|
@ -487,11 +488,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
|
||||
|
||||
# Create a new login for the user and dehydrated the device
|
||||
device_id, access_token, _expiration_time, _refresh_token = self.get_success(
|
||||
device_id, access_token, _expiration_time, refresh_token = self.get_success(
|
||||
self.registration.register_device(
|
||||
user_id=user_id,
|
||||
device_id=None,
|
||||
initial_display_name="new device",
|
||||
should_issue_refresh_token=True,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -522,6 +524,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertEqual(user_info.device_id, retrieved_device_id)
|
||||
|
||||
# make sure the user device has the refresh token
|
||||
assert refresh_token is not None
|
||||
self.get_success(
|
||||
self.auth_handler.refresh_token(refresh_token, 5 * 60 * 1000, 5 * 60 * 1000)
|
||||
)
|
||||
|
||||
# make sure the device has the display name that was set from the login
|
||||
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ from twisted.test.proto_helpers import MemoryReactor
|
|||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import ThreepidValidationError
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.types import JsonDict, UserID, UserInfo
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.unittest import HomeserverTestCase, override_config
|
||||
|
@ -35,24 +35,22 @@ class RegistrationStoreTestCase(HomeserverTestCase):
|
|||
self.get_success(self.store.register_user(self.user_id, self.pwhash))
|
||||
|
||||
self.assertEqual(
|
||||
{
|
||||
UserInfo(
|
||||
# TODO(paul): Surely this field should be 'user_id', not 'name'
|
||||
"name": self.user_id,
|
||||
"password_hash": self.pwhash,
|
||||
"admin": 0,
|
||||
"is_guest": 0,
|
||||
"consent_version": None,
|
||||
"consent_ts": None,
|
||||
"consent_server_notice_sent": None,
|
||||
"appservice_id": None,
|
||||
"creation_ts": 0,
|
||||
"user_type": None,
|
||||
"deactivated": 0,
|
||||
"locked": 0,
|
||||
"shadow_banned": 0,
|
||||
"approved": 1,
|
||||
"last_seen_ts": None,
|
||||
},
|
||||
user_id=UserID.from_string(self.user_id),
|
||||
is_admin=False,
|
||||
is_guest=False,
|
||||
consent_server_notice_sent=None,
|
||||
consent_ts=None,
|
||||
consent_version=None,
|
||||
appservice_id=None,
|
||||
creation_ts=0,
|
||||
user_type=None,
|
||||
is_deactivated=False,
|
||||
locked=False,
|
||||
is_shadow_banned=False,
|
||||
approved=True,
|
||||
),
|
||||
(self.get_success(self.store.get_user_by_id(self.user_id))),
|
||||
)
|
||||
|
||||
|
@ -65,9 +63,11 @@ class RegistrationStoreTestCase(HomeserverTestCase):
|
|||
|
||||
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
||||
assert user
|
||||
self.assertEqual(user["consent_version"], "1")
|
||||
self.assertGreater(user["consent_ts"], before_consent)
|
||||
self.assertLess(user["consent_ts"], self.clock.time_msec())
|
||||
self.assertEqual(user.consent_version, "1")
|
||||
self.assertIsNotNone(user.consent_ts)
|
||||
assert user.consent_ts is not None
|
||||
self.assertGreater(user.consent_ts, before_consent)
|
||||
self.assertLess(user.consent_ts, self.clock.time_msec())
|
||||
|
||||
def test_add_tokens(self) -> None:
|
||||
self.get_success(self.store.register_user(self.user_id, self.pwhash))
|
||||
|
@ -215,7 +215,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
|
|||
|
||||
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
||||
assert user is not None
|
||||
self.assertTrue(user["approved"])
|
||||
self.assertTrue(user.approved)
|
||||
|
||||
approved = self.get_success(self.store.is_user_approved(self.user_id))
|
||||
self.assertTrue(approved)
|
||||
|
@ -228,7 +228,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
|
|||
|
||||
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
||||
assert user is not None
|
||||
self.assertFalse(user["approved"])
|
||||
self.assertFalse(user.approved)
|
||||
|
||||
approved = self.get_success(self.store.is_user_approved(self.user_id))
|
||||
self.assertFalse(approved)
|
||||
|
@ -248,7 +248,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
|
|||
user = self.get_success(self.store.get_user_by_id(self.user_id))
|
||||
self.assertIsNotNone(user)
|
||||
assert user is not None
|
||||
self.assertEqual(user["approved"], 1)
|
||||
self.assertEqual(user.approved, 1)
|
||||
|
||||
approved = self.get_success(self.store.is_user_approved(self.user_id))
|
||||
self.assertTrue(approved)
|
||||
|
|
Loading…
Reference in a new issue