mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-23 10:05:55 +03:00
Format files with Ruff (#17643)
I thought ruff check would also format, but it doesn't. This runs ruff format in CI and dev scripts. The first commit is just a run of `ruff format .` in the root directory.
This commit is contained in:
parent
709b7363fe
commit
7d52ce7d4b
152 changed files with 526 additions and 492 deletions
6
.github/workflows/fix_lint.yaml
vendored
6
.github/workflows/fix_lint.yaml
vendored
|
@ -29,10 +29,14 @@ jobs:
|
||||||
with:
|
with:
|
||||||
install-project: "false"
|
install-project: "false"
|
||||||
|
|
||||||
- name: Run ruff
|
- name: Run ruff check
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
run: poetry run ruff check --fix .
|
run: poetry run ruff check --fix .
|
||||||
|
|
||||||
|
- name: Run ruff format
|
||||||
|
continue-on-error: true
|
||||||
|
run: poetry run ruff format --quiet .
|
||||||
|
|
||||||
- run: cargo clippy --all-features --fix -- -D warnings
|
- run: cargo clippy --all-features --fix -- -D warnings
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
|
|
||||||
|
|
5
.github/workflows/tests.yml
vendored
5
.github/workflows/tests.yml
vendored
|
@ -131,9 +131,12 @@ jobs:
|
||||||
with:
|
with:
|
||||||
install-project: "false"
|
install-project: "false"
|
||||||
|
|
||||||
- name: Check style
|
- name: Run ruff check
|
||||||
run: poetry run ruff check --output-format=github .
|
run: poetry run ruff check --output-format=github .
|
||||||
|
|
||||||
|
- name: Run ruff format
|
||||||
|
run: poetry run ruff format --check .
|
||||||
|
|
||||||
lint-mypy:
|
lint-mypy:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
name: Typechecking
|
name: Typechecking
|
||||||
|
|
1
changelog.d/17643.misc
Normal file
1
changelog.d/17643.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Replace `isort` and `black with `ruff`.
|
|
@ -22,6 +22,7 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
"""Starts a synapse client console."""
|
"""Starts a synapse client console."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import binascii
|
import binascii
|
||||||
import cmd
|
import cmd
|
||||||
|
|
|
@ -31,6 +31,7 @@ Pydantic does not yet offer a strict mode, but it is planned for pydantic v2. Se
|
||||||
until then, this script is a best effort to stop us from introducing type coersion bugs
|
until then, this script is a best effort to stop us from introducing type coersion bugs
|
||||||
(like the infamous stringy power levels fixed in room version 10).
|
(like the infamous stringy power levels fixed in room version 10).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import contextlib
|
import contextlib
|
||||||
import functools
|
import functools
|
||||||
|
|
|
@ -109,6 +109,9 @@ set -x
|
||||||
# --quiet suppresses the update check.
|
# --quiet suppresses the update check.
|
||||||
ruff check --quiet --fix "${files[@]}"
|
ruff check --quiet --fix "${files[@]}"
|
||||||
|
|
||||||
|
# Reformat Python code.
|
||||||
|
ruff format --quiet "${files[@]}"
|
||||||
|
|
||||||
# Catch any common programming mistakes in Rust code.
|
# Catch any common programming mistakes in Rust code.
|
||||||
#
|
#
|
||||||
# --bins, --examples, --lib, --tests combined explicitly disable checking
|
# --bins, --examples, --lib, --tests combined explicitly disable checking
|
||||||
|
|
|
@ -20,8 +20,7 @@
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
|
|
||||||
"""An interactive script for doing a release. See `cli()` below.
|
"""An interactive script for doing a release. See `cli()` below."""
|
||||||
"""
|
|
||||||
|
|
||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
|
|
|
@ -13,8 +13,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Contains *incomplete* type hints for txredisapi.
|
"""Contains *incomplete* type hints for txredisapi."""
|
||||||
"""
|
|
||||||
from typing import Any, List, Optional, Type, Union
|
from typing import Any, List, Optional, Type, Union
|
||||||
|
|
||||||
from twisted.internet import protocol
|
from twisted.internet import protocol
|
||||||
|
|
|
@ -20,8 +20,7 @@
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
|
|
||||||
""" This is an implementation of a Matrix homeserver.
|
"""This is an implementation of a Matrix homeserver."""
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
|
@ -171,7 +171,7 @@ def elide_http_methods_if_unconflicting(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def paths_to_methods_dict(
|
def paths_to_methods_dict(
|
||||||
methods_and_paths: Iterable[Tuple[str, str]]
|
methods_and_paths: Iterable[Tuple[str, str]],
|
||||||
) -> Dict[str, Set[str]]:
|
) -> Dict[str, Set[str]]:
|
||||||
"""
|
"""
|
||||||
Given (method, path) pairs, produces a dict from path to set of methods
|
Given (method, path) pairs, produces a dict from path to set of methods
|
||||||
|
@ -201,7 +201,7 @@ def elide_http_methods_if_unconflicting(
|
||||||
|
|
||||||
|
|
||||||
def simplify_path_regexes(
|
def simplify_path_regexes(
|
||||||
registrations: Dict[Tuple[str, str], EndpointDescription]
|
registrations: Dict[Tuple[str, str], EndpointDescription],
|
||||||
) -> Dict[Tuple[str, str], EndpointDescription]:
|
) -> Dict[Tuple[str, str], EndpointDescription]:
|
||||||
"""
|
"""
|
||||||
Simplify all the path regexes for the dict of endpoint descriptions,
|
Simplify all the path regexes for the dict of endpoint descriptions,
|
||||||
|
|
|
@ -40,6 +40,7 @@ from synapse.storage.engines import create_engine
|
||||||
|
|
||||||
class ReviewConfig(RootConfig):
|
class ReviewConfig(RootConfig):
|
||||||
"A config class that just pulls out the database config"
|
"A config class that just pulls out the database config"
|
||||||
|
|
||||||
config_classes = [DatabaseConfig]
|
config_classes = [DatabaseConfig]
|
||||||
|
|
||||||
|
|
||||||
|
@ -160,7 +161,11 @@ def main() -> None:
|
||||||
|
|
||||||
with make_conn(database_config, engine, "review_recent_signups") as db_conn:
|
with make_conn(database_config, engine, "review_recent_signups") as db_conn:
|
||||||
# This generates a type of Cursor, not LoggingTransaction.
|
# This generates a type of Cursor, not LoggingTransaction.
|
||||||
user_infos = get_recent_users(db_conn.cursor(), since_ms, exclude_users_with_appservice) # type: ignore[arg-type]
|
user_infos = get_recent_users(
|
||||||
|
db_conn.cursor(),
|
||||||
|
since_ms, # type: ignore[arg-type]
|
||||||
|
exclude_users_with_appservice,
|
||||||
|
)
|
||||||
|
|
||||||
for user_info in user_infos:
|
for user_info in user_infos:
|
||||||
if exclude_users_with_email and user_info.emails:
|
if exclude_users_with_email and user_info.emails:
|
||||||
|
|
|
@ -717,9 +717,7 @@ class Porter:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if all background updates are done, abort if not.
|
# Check if all background updates are done, abort if not.
|
||||||
updates_complete = (
|
updates_complete = await self.sqlite_store.db_pool.updates.has_completed_background_updates()
|
||||||
await self.sqlite_store.db_pool.updates.has_completed_background_updates()
|
|
||||||
)
|
|
||||||
if not updates_complete:
|
if not updates_complete:
|
||||||
end_error = (
|
end_error = (
|
||||||
"Pending background updates exist in the SQLite3 database."
|
"Pending background updates exist in the SQLite3 database."
|
||||||
|
@ -1095,11 +1093,11 @@ class Porter:
|
||||||
return done, remaining + done
|
return done, remaining + done
|
||||||
|
|
||||||
async def _setup_state_group_id_seq(self) -> None:
|
async def _setup_state_group_id_seq(self) -> None:
|
||||||
curr_id: Optional[int] = (
|
curr_id: Optional[
|
||||||
await self.sqlite_store.db_pool.simple_select_one_onecol(
|
int
|
||||||
|
] = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
||||||
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
|
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if not curr_id:
|
if not curr_id:
|
||||||
return
|
return
|
||||||
|
@ -1186,14 +1184,14 @@ class Porter:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _setup_auth_chain_sequence(self) -> None:
|
async def _setup_auth_chain_sequence(self) -> None:
|
||||||
curr_chain_id: Optional[int] = (
|
curr_chain_id: Optional[
|
||||||
await self.sqlite_store.db_pool.simple_select_one_onecol(
|
int
|
||||||
|
] = await self.sqlite_store.db_pool.simple_select_one_onecol(
|
||||||
table="event_auth_chains",
|
table="event_auth_chains",
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
retcol="MAX(chain_id)",
|
retcol="MAX(chain_id)",
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
def r(txn: LoggingTransaction) -> None:
|
def r(txn: LoggingTransaction) -> None:
|
||||||
# Presumably there is at least one row in event_auth_chains.
|
# Presumably there is at least one row in event_auth_chains.
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
"""Contains the URL paths to prefix various aspects of the server with."""
|
"""Contains the URL paths to prefix various aspects of the server with."""
|
||||||
|
|
||||||
import hmac
|
import hmac
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
|
@ -54,6 +54,7 @@ UP & quit +---------- YES SUCCESS
|
||||||
This is all tied together by the AppServiceScheduler which DIs the required
|
This is all tied together by the AppServiceScheduler which DIs the required
|
||||||
components.
|
components.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
|
|
|
@ -200,16 +200,13 @@ class KeyConfig(Config):
|
||||||
)
|
)
|
||||||
form_secret = 'form_secret: "%s"' % random_string_with_symbols(50)
|
form_secret = 'form_secret: "%s"' % random_string_with_symbols(50)
|
||||||
|
|
||||||
return (
|
return """\
|
||||||
"""\
|
|
||||||
%(macaroon_secret_key)s
|
%(macaroon_secret_key)s
|
||||||
%(form_secret)s
|
%(form_secret)s
|
||||||
signing_key_path: "%(base_key_name)s.signing.key"
|
signing_key_path: "%(base_key_name)s.signing.key"
|
||||||
trusted_key_servers:
|
trusted_key_servers:
|
||||||
- server_name: "matrix.org"
|
- server_name: "matrix.org"
|
||||||
"""
|
""" % locals()
|
||||||
% locals()
|
|
||||||
)
|
|
||||||
|
|
||||||
def read_signing_keys(self, signing_key_path: str, name: str) -> List[SigningKey]:
|
def read_signing_keys(self, signing_key_path: str, name: str) -> List[SigningKey]:
|
||||||
"""Read the signing keys in the given path.
|
"""Read the signing keys in the given path.
|
||||||
|
@ -249,7 +246,9 @@ class KeyConfig(Config):
|
||||||
if is_signing_algorithm_supported(key_id):
|
if is_signing_algorithm_supported(key_id):
|
||||||
key_base64 = key_data["key"]
|
key_base64 = key_data["key"]
|
||||||
key_bytes = decode_base64(key_base64)
|
key_bytes = decode_base64(key_base64)
|
||||||
verify_key: "VerifyKeyWithExpiry" = decode_verify_key_bytes(key_id, key_bytes) # type: ignore[assignment]
|
verify_key: "VerifyKeyWithExpiry" = decode_verify_key_bytes(
|
||||||
|
key_id, key_bytes
|
||||||
|
) # type: ignore[assignment]
|
||||||
verify_key.expired = key_data["expired_ts"]
|
verify_key.expired = key_data["expired_ts"]
|
||||||
keys[key_id] = verify_key
|
keys[key_id] = verify_key
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -157,12 +157,9 @@ class LoggingConfig(Config):
|
||||||
self, config_dir_path: str, server_name: str, **kwargs: Any
|
self, config_dir_path: str, server_name: str, **kwargs: Any
|
||||||
) -> str:
|
) -> str:
|
||||||
log_config = os.path.join(config_dir_path, server_name + ".log.config")
|
log_config = os.path.join(config_dir_path, server_name + ".log.config")
|
||||||
return (
|
return """\
|
||||||
"""\
|
|
||||||
log_config: "%(log_config)s"
|
log_config: "%(log_config)s"
|
||||||
"""
|
""" % locals()
|
||||||
% locals()
|
|
||||||
)
|
|
||||||
|
|
||||||
def read_arguments(self, args: argparse.Namespace) -> None:
|
def read_arguments(self, args: argparse.Namespace) -> None:
|
||||||
if args.no_redirect_stdio is not None:
|
if args.no_redirect_stdio is not None:
|
||||||
|
|
|
@ -828,13 +828,10 @@ class ServerConfig(Config):
|
||||||
).lstrip()
|
).lstrip()
|
||||||
|
|
||||||
if not unsecure_listeners:
|
if not unsecure_listeners:
|
||||||
unsecure_http_bindings = (
|
unsecure_http_bindings = """- port: %(unsecure_port)s
|
||||||
"""- port: %(unsecure_port)s
|
|
||||||
tls: false
|
tls: false
|
||||||
type: http
|
type: http
|
||||||
x_forwarded: true"""
|
x_forwarded: true""" % locals()
|
||||||
% locals()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not open_private_ports:
|
if not open_private_ports:
|
||||||
unsecure_http_bindings += (
|
unsecure_http_bindings += (
|
||||||
|
@ -853,16 +850,13 @@ class ServerConfig(Config):
|
||||||
if not secure_listeners:
|
if not secure_listeners:
|
||||||
secure_http_bindings = ""
|
secure_http_bindings = ""
|
||||||
|
|
||||||
return (
|
return """\
|
||||||
"""\
|
|
||||||
server_name: "%(server_name)s"
|
server_name: "%(server_name)s"
|
||||||
pid_file: %(pid_file)s
|
pid_file: %(pid_file)s
|
||||||
listeners:
|
listeners:
|
||||||
%(secure_http_bindings)s
|
%(secure_http_bindings)s
|
||||||
%(unsecure_http_bindings)s
|
%(unsecure_http_bindings)s
|
||||||
"""
|
""" % locals()
|
||||||
% locals()
|
|
||||||
)
|
|
||||||
|
|
||||||
def read_arguments(self, args: argparse.Namespace) -> None:
|
def read_arguments(self, args: argparse.Namespace) -> None:
|
||||||
if args.manhole is not None:
|
if args.manhole is not None:
|
||||||
|
|
|
@ -328,10 +328,11 @@ class WorkerConfig(Config):
|
||||||
)
|
)
|
||||||
|
|
||||||
# type-ignore: the expression `Union[A, B]` is not a Type[Union[A, B]] currently
|
# type-ignore: the expression `Union[A, B]` is not a Type[Union[A, B]] currently
|
||||||
self.instance_map: Dict[
|
self.instance_map: Dict[str, InstanceLocationConfig] = (
|
||||||
str, InstanceLocationConfig
|
parse_and_validate_mapping(
|
||||||
] = parse_and_validate_mapping(
|
instance_map,
|
||||||
instance_map, InstanceLocationConfig # type: ignore[arg-type]
|
InstanceLocationConfig, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Map from type of streams to source, c.f. WriterLocations.
|
# Map from type of streams to source, c.f. WriterLocations.
|
||||||
|
|
|
@ -887,7 +887,8 @@ def _check_power_levels(
|
||||||
raise SynapseError(400, f"{v!r} must be an integer.")
|
raise SynapseError(400, f"{v!r} must be an integer.")
|
||||||
if k in {"events", "notifications", "users"}:
|
if k in {"events", "notifications", "users"}:
|
||||||
if not isinstance(v, collections.abc.Mapping) or not all(
|
if not isinstance(v, collections.abc.Mapping) or not all(
|
||||||
type(v) is int for v in v.values() # noqa: E721
|
type(v) is int
|
||||||
|
for v in v.values() # noqa: E721
|
||||||
):
|
):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
|
|
|
@ -80,7 +80,7 @@ def load_legacy_presence_router(hs: "HomeServer") -> None:
|
||||||
# All methods that the module provides should be async, but this wasn't enforced
|
# All methods that the module provides should be async, but this wasn't enforced
|
||||||
# in the old module system, so we wrap them if needed
|
# in the old module system, so we wrap them if needed
|
||||||
def async_wrapper(
|
def async_wrapper(
|
||||||
f: Optional[Callable[P, R]]
|
f: Optional[Callable[P, R]],
|
||||||
) -> Optional[Callable[P, Awaitable[R]]]:
|
) -> Optional[Callable[P, Awaitable[R]]]:
|
||||||
# f might be None if the callback isn't implemented by the module. In this
|
# f might be None if the callback isn't implemented by the module. In this
|
||||||
# case we don't want to register a callback at all so we return None.
|
# case we don't want to register a callback at all so we return None.
|
||||||
|
|
|
@ -504,7 +504,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
|
||||||
|
|
||||||
|
|
||||||
def _encode_state_group_delta(
|
def _encode_state_group_delta(
|
||||||
state_group_delta: Dict[Tuple[int, int], StateMap[str]]
|
state_group_delta: Dict[Tuple[int, int], StateMap[str]],
|
||||||
) -> List[Tuple[int, int, Optional[List[Tuple[str, str, str]]]]]:
|
) -> List[Tuple[int, int, Optional[List[Tuple[str, str, str]]]]]:
|
||||||
if not state_group_delta:
|
if not state_group_delta:
|
||||||
return []
|
return []
|
||||||
|
@ -517,7 +517,7 @@ def _encode_state_group_delta(
|
||||||
|
|
||||||
|
|
||||||
def _decode_state_group_delta(
|
def _decode_state_group_delta(
|
||||||
input: List[Tuple[int, int, List[Tuple[str, str, str]]]]
|
input: List[Tuple[int, int, List[Tuple[str, str, str]]]],
|
||||||
) -> Dict[Tuple[int, int], StateMap[str]]:
|
) -> Dict[Tuple[int, int], StateMap[str]]:
|
||||||
if not input:
|
if not input:
|
||||||
return {}
|
return {}
|
||||||
|
@ -544,7 +544,7 @@ def _encode_state_dict(
|
||||||
|
|
||||||
|
|
||||||
def _decode_state_dict(
|
def _decode_state_dict(
|
||||||
input: Optional[List[Tuple[str, str, str]]]
|
input: Optional[List[Tuple[str, str, str]]],
|
||||||
) -> Optional[StateMap[str]]:
|
) -> Optional[StateMap[str]]:
|
||||||
"""Decodes a state dict encoded using `_encode_state_dict` above"""
|
"""Decodes a state dict encoded using `_encode_state_dict` above"""
|
||||||
if input is None:
|
if input is None:
|
||||||
|
|
|
@ -19,5 +19,4 @@
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
|
|
||||||
""" This package includes all the federation specific logic.
|
"""This package includes all the federation specific logic."""
|
||||||
"""
|
|
||||||
|
|
|
@ -859,7 +859,6 @@ class FederationMediaThumbnailServlet(BaseFederationServerServlet):
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
media_id: str,
|
media_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
width = parse_integer(request, "width", required=True)
|
width = parse_integer(request, "width", required=True)
|
||||||
height = parse_integer(request, "height", required=True)
|
height = parse_integer(request, "height", required=True)
|
||||||
method = parse_string(request, "method", "scale")
|
method = parse_string(request, "method", "scale")
|
||||||
|
|
|
@ -118,11 +118,11 @@ class AccountHandler:
|
||||||
}
|
}
|
||||||
|
|
||||||
if self._use_account_validity_in_account_status:
|
if self._use_account_validity_in_account_status:
|
||||||
status["org.matrix.expired"] = (
|
status[
|
||||||
await self._account_validity_handler.is_user_expired(
|
"org.matrix.expired"
|
||||||
|
] = await self._account_validity_handler.is_user_expired(
|
||||||
user_id.to_string()
|
user_id.to_string()
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return status
|
return status
|
||||||
|
|
||||||
|
|
|
@ -197,15 +197,16 @@ class AdminHandler:
|
||||||
# events that we have and then filtering, this isn't the most
|
# events that we have and then filtering, this isn't the most
|
||||||
# efficient method perhaps but it does guarantee we get everything.
|
# efficient method perhaps but it does guarantee we get everything.
|
||||||
while True:
|
while True:
|
||||||
events, _ = (
|
(
|
||||||
await self._store.paginate_room_events_by_topological_ordering(
|
events,
|
||||||
|
_,
|
||||||
|
) = await self._store.paginate_room_events_by_topological_ordering(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
from_key=from_key,
|
from_key=from_key,
|
||||||
to_key=to_key,
|
to_key=to_key,
|
||||||
limit=100,
|
limit=100,
|
||||||
direction=Direction.FORWARDS,
|
direction=Direction.FORWARDS,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
if not events:
|
if not events:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
|
@ -166,8 +166,7 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
|
||||||
if "country" not in identifier or (
|
if "country" not in identifier or (
|
||||||
# The specification requires a "phone" field, while Synapse used to require a "number"
|
# The specification requires a "phone" field, while Synapse used to require a "number"
|
||||||
# field. Accept both for backwards compatibility.
|
# field. Accept both for backwards compatibility.
|
||||||
"phone" not in identifier
|
"phone" not in identifier and "number" not in identifier
|
||||||
and "number" not in identifier
|
|
||||||
):
|
):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM
|
400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM
|
||||||
|
|
|
@ -265,9 +265,9 @@ class DirectoryHandler:
|
||||||
async def get_association(self, room_alias: RoomAlias) -> JsonDict:
|
async def get_association(self, room_alias: RoomAlias) -> JsonDict:
|
||||||
room_id = None
|
room_id = None
|
||||||
if self.hs.is_mine(room_alias):
|
if self.hs.is_mine(room_alias):
|
||||||
result: Optional[RoomAliasMapping] = (
|
result: Optional[
|
||||||
await self.get_association_from_room_alias(room_alias)
|
RoomAliasMapping
|
||||||
)
|
] = await self.get_association_from_room_alias(room_alias)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
room_id = result.room_id
|
room_id = result.room_id
|
||||||
|
@ -512,13 +512,11 @@ class DirectoryHandler:
|
||||||
raise SynapseError(403, "Not allowed to publish room")
|
raise SynapseError(403, "Not allowed to publish room")
|
||||||
|
|
||||||
# Check if publishing is blocked by a third party module
|
# Check if publishing is blocked by a third party module
|
||||||
allowed_by_third_party_rules = (
|
allowed_by_third_party_rules = await (
|
||||||
await (
|
|
||||||
self._third_party_event_rules.check_visibility_can_be_modified(
|
self._third_party_event_rules.check_visibility_can_be_modified(
|
||||||
room_id, visibility
|
room_id, visibility
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
if not allowed_by_third_party_rules:
|
if not allowed_by_third_party_rules:
|
||||||
raise SynapseError(403, "Not allowed to publish room")
|
raise SynapseError(403, "Not allowed to publish room")
|
||||||
|
|
||||||
|
|
|
@ -1001,12 +1001,12 @@ class FederationHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
if include_auth_user_id:
|
if include_auth_user_id:
|
||||||
event_content[EventContentFields.AUTHORISING_USER] = (
|
event_content[
|
||||||
await self._event_auth_handler.get_user_which_could_invite(
|
EventContentFields.AUTHORISING_USER
|
||||||
|
] = await self._event_auth_handler.get_user_which_could_invite(
|
||||||
room_id,
|
room_id,
|
||||||
state_ids,
|
state_ids,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
builder = self.event_builder_factory.for_room_version(
|
builder = self.event_builder_factory.for_room_version(
|
||||||
room_version,
|
room_version,
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
"""Utilities for interacting with Identity Servers"""
|
"""Utilities for interacting with Identity Servers"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
|
@ -1225,10 +1225,9 @@ class EventCreationHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
if prev_event_ids is not None:
|
if prev_event_ids is not None:
|
||||||
assert (
|
assert len(prev_event_ids) <= 10, (
|
||||||
len(prev_event_ids) <= 10
|
"Attempting to create an event with %i prev_events"
|
||||||
), "Attempting to create an event with %i prev_events" % (
|
% (len(prev_event_ids),)
|
||||||
len(prev_event_ids),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
|
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
|
||||||
|
|
|
@ -507,8 +507,10 @@ class PaginationHandler:
|
||||||
|
|
||||||
# Initially fetch the events from the database. With any luck, we can return
|
# Initially fetch the events from the database. With any luck, we can return
|
||||||
# these without blocking on backfill (handled below).
|
# these without blocking on backfill (handled below).
|
||||||
events, next_key = (
|
(
|
||||||
await self.store.paginate_room_events_by_topological_ordering(
|
events,
|
||||||
|
next_key,
|
||||||
|
) = await self.store.paginate_room_events_by_topological_ordering(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
from_key=from_token.room_key,
|
from_key=from_token.room_key,
|
||||||
to_key=to_room_key,
|
to_key=to_room_key,
|
||||||
|
@ -516,7 +518,6 @@ class PaginationHandler:
|
||||||
limit=pagin_config.limit,
|
limit=pagin_config.limit,
|
||||||
event_filter=event_filter,
|
event_filter=event_filter,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if pagin_config.direction == Direction.BACKWARDS:
|
if pagin_config.direction == Direction.BACKWARDS:
|
||||||
# We use a `Set` because there can be multiple events at a given depth
|
# We use a `Set` because there can be multiple events at a given depth
|
||||||
|
@ -584,8 +585,10 @@ class PaginationHandler:
|
||||||
# If we did backfill something, refetch the events from the database to
|
# If we did backfill something, refetch the events from the database to
|
||||||
# catch anything new that might have been added since we last fetched.
|
# catch anything new that might have been added since we last fetched.
|
||||||
if did_backfill:
|
if did_backfill:
|
||||||
events, next_key = (
|
(
|
||||||
await self.store.paginate_room_events_by_topological_ordering(
|
events,
|
||||||
|
next_key,
|
||||||
|
) = await self.store.paginate_room_events_by_topological_ordering(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
from_key=from_token.room_key,
|
from_key=from_token.room_key,
|
||||||
to_key=to_room_key,
|
to_key=to_room_key,
|
||||||
|
@ -593,7 +596,6 @@ class PaginationHandler:
|
||||||
limit=pagin_config.limit,
|
limit=pagin_config.limit,
|
||||||
event_filter=event_filter,
|
event_filter=event_filter,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Otherwise, we can backfill in the background for eventual
|
# Otherwise, we can backfill in the background for eventual
|
||||||
# consistency's sake but we don't need to block the client waiting
|
# consistency's sake but we don't need to block the client waiting
|
||||||
|
|
|
@ -71,6 +71,7 @@ user state; this device follows the normal timeout logic (see above) and will
|
||||||
automatically be replaced with any information from currently available devices.
|
automatically be replaced with any information from currently available devices.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import contextlib
|
import contextlib
|
||||||
import itertools
|
import itertools
|
||||||
|
@ -493,9 +494,9 @@ class WorkerPresenceHandler(BasePresenceHandler):
|
||||||
|
|
||||||
# The number of ongoing syncs on this process, by (user ID, device ID).
|
# The number of ongoing syncs on this process, by (user ID, device ID).
|
||||||
# Empty if _presence_enabled is false.
|
# Empty if _presence_enabled is false.
|
||||||
self._user_device_to_num_current_syncs: Dict[Tuple[str, Optional[str]], int] = (
|
self._user_device_to_num_current_syncs: Dict[
|
||||||
{}
|
Tuple[str, Optional[str]], int
|
||||||
)
|
] = {}
|
||||||
|
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.instance_id = hs.get_instance_id()
|
self.instance_id = hs.get_instance_id()
|
||||||
|
@ -818,9 +819,9 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
|
|
||||||
# Keeps track of the number of *ongoing* syncs on this process. While
|
# Keeps track of the number of *ongoing* syncs on this process. While
|
||||||
# this is non zero a user will never go offline.
|
# this is non zero a user will never go offline.
|
||||||
self._user_device_to_num_current_syncs: Dict[Tuple[str, Optional[str]], int] = (
|
self._user_device_to_num_current_syncs: Dict[
|
||||||
{}
|
Tuple[str, Optional[str]], int
|
||||||
)
|
] = {}
|
||||||
|
|
||||||
# Keeps track of the number of *ongoing* syncs on other processes.
|
# Keeps track of the number of *ongoing* syncs on other processes.
|
||||||
#
|
#
|
||||||
|
|
|
@ -351,9 +351,9 @@ class ProfileHandler:
|
||||||
server_name = host
|
server_name = host
|
||||||
|
|
||||||
if self._is_mine_server_name(server_name):
|
if self._is_mine_server_name(server_name):
|
||||||
media_info: Optional[Union[LocalMedia, RemoteMedia]] = (
|
media_info: Optional[
|
||||||
await self.store.get_local_media(media_id)
|
Union[LocalMedia, RemoteMedia]
|
||||||
)
|
] = await self.store.get_local_media(media_id)
|
||||||
else:
|
else:
|
||||||
media_info = await self.store.get_cached_remote_media(server_name, media_id)
|
media_info = await self.store.get_cached_remote_media(server_name, media_id)
|
||||||
|
|
||||||
|
|
|
@ -188,14 +188,14 @@ class RelationsHandler:
|
||||||
if include_original_event:
|
if include_original_event:
|
||||||
# Do not bundle aggregations when retrieving the original event because
|
# Do not bundle aggregations when retrieving the original event because
|
||||||
# we want the content before relations are applied to it.
|
# we want the content before relations are applied to it.
|
||||||
return_value["original_event"] = (
|
return_value[
|
||||||
await self._event_serializer.serialize_event(
|
"original_event"
|
||||||
|
] = await self._event_serializer.serialize_event(
|
||||||
event,
|
event,
|
||||||
now,
|
now,
|
||||||
bundle_aggregations=None,
|
bundle_aggregations=None,
|
||||||
config=serialize_options,
|
config=serialize_options,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if next_token:
|
if next_token:
|
||||||
return_value["next_batch"] = await next_token.to_string(self._main_store)
|
return_value["next_batch"] = await next_token.to_string(self._main_store)
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
"""Contains functions for performing actions on rooms."""
|
"""Contains functions for performing actions on rooms."""
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
@ -900,13 +901,11 @@ class RoomCreationHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check whether this visibility value is blocked by a third party module
|
# Check whether this visibility value is blocked by a third party module
|
||||||
allowed_by_third_party_rules = (
|
allowed_by_third_party_rules = await (
|
||||||
await (
|
|
||||||
self._third_party_event_rules.check_visibility_can_be_modified(
|
self._third_party_event_rules.check_visibility_can_be_modified(
|
||||||
room_id, visibility
|
room_id, visibility
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
if not allowed_by_third_party_rules:
|
if not allowed_by_third_party_rules:
|
||||||
raise SynapseError(403, "Room visibility value not allowed.")
|
raise SynapseError(403, "Room visibility value not allowed.")
|
||||||
|
|
||||||
|
|
|
@ -1302,12 +1302,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
# If this is going to be a local join, additional information must
|
# If this is going to be a local join, additional information must
|
||||||
# be included in the event content in order to efficiently validate
|
# be included in the event content in order to efficiently validate
|
||||||
# the event.
|
# the event.
|
||||||
content[EventContentFields.AUTHORISING_USER] = (
|
content[
|
||||||
await self.event_auth_handler.get_user_which_could_invite(
|
EventContentFields.AUTHORISING_USER
|
||||||
|
] = await self.event_auth_handler.get_user_which_could_invite(
|
||||||
room_id,
|
room_id,
|
||||||
state_before_join,
|
state_before_join,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return False, []
|
return False, []
|
||||||
|
|
||||||
|
@ -1415,9 +1415,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
if requester is not None:
|
if requester is not None:
|
||||||
sender = UserID.from_string(event.sender)
|
sender = UserID.from_string(event.sender)
|
||||||
assert (
|
assert sender == requester.user, (
|
||||||
sender == requester.user
|
"Sender (%s) must be same as requester (%s)" % (sender, requester.user)
|
||||||
), "Sender (%s) must be same as requester (%s)" % (sender, requester.user)
|
)
|
||||||
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
|
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
|
||||||
else:
|
else:
|
||||||
requester = types.create_requester(target_user)
|
requester = types.create_requester(target_user)
|
||||||
|
|
|
@ -423,9 +423,9 @@ class SearchHandler:
|
||||||
}
|
}
|
||||||
|
|
||||||
if search_result.room_groups and "room_id" in group_keys:
|
if search_result.room_groups and "room_id" in group_keys:
|
||||||
rooms_cat_res.setdefault("groups", {})[
|
rooms_cat_res.setdefault("groups", {})["room_id"] = (
|
||||||
"room_id"
|
search_result.room_groups
|
||||||
] = search_result.room_groups
|
)
|
||||||
|
|
||||||
if sender_group and "sender" in group_keys:
|
if sender_group and "sender" in group_keys:
|
||||||
rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
|
rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
|
||||||
|
|
|
@ -587,9 +587,7 @@ class SlidingSyncHandler:
|
||||||
Membership.LEAVE,
|
Membership.LEAVE,
|
||||||
Membership.BAN,
|
Membership.BAN,
|
||||||
):
|
):
|
||||||
to_bound = (
|
to_bound = room_membership_for_user_at_to_token.event_pos.to_room_stream_token()
|
||||||
room_membership_for_user_at_to_token.event_pos.to_room_stream_token()
|
|
||||||
)
|
|
||||||
|
|
||||||
timeline_from_bound = from_bound
|
timeline_from_bound = from_bound
|
||||||
if ignore_timeline_bound:
|
if ignore_timeline_bound:
|
||||||
|
|
|
@ -386,9 +386,9 @@ class SlidingSyncExtensionHandler:
|
||||||
if have_push_rules_changed:
|
if have_push_rules_changed:
|
||||||
global_account_data_map = dict(global_account_data_map)
|
global_account_data_map = dict(global_account_data_map)
|
||||||
# TODO: This should take into account the `from_token` and `to_token`
|
# TODO: This should take into account the `from_token` and `to_token`
|
||||||
global_account_data_map[AccountDataTypes.PUSH_RULES] = (
|
global_account_data_map[
|
||||||
await self.push_rules_handler.push_rules_for_user(sync_config.user)
|
AccountDataTypes.PUSH_RULES
|
||||||
)
|
] = await self.push_rules_handler.push_rules_for_user(sync_config.user)
|
||||||
else:
|
else:
|
||||||
# TODO: This should take into account the `to_token`
|
# TODO: This should take into account the `to_token`
|
||||||
all_global_account_data = await self.store.get_global_account_data_for_user(
|
all_global_account_data = await self.store.get_global_account_data_for_user(
|
||||||
|
@ -397,9 +397,9 @@ class SlidingSyncExtensionHandler:
|
||||||
|
|
||||||
global_account_data_map = dict(all_global_account_data)
|
global_account_data_map = dict(all_global_account_data)
|
||||||
# TODO: This should take into account the `to_token`
|
# TODO: This should take into account the `to_token`
|
||||||
global_account_data_map[AccountDataTypes.PUSH_RULES] = (
|
global_account_data_map[
|
||||||
await self.push_rules_handler.push_rules_for_user(sync_config.user)
|
AccountDataTypes.PUSH_RULES
|
||||||
)
|
] = await self.push_rules_handler.push_rules_for_user(sync_config.user)
|
||||||
|
|
||||||
# Fetch room account data
|
# Fetch room account data
|
||||||
account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]] = {}
|
account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]] = {}
|
||||||
|
|
|
@ -293,11 +293,12 @@ class SlidingSyncRoomLists:
|
||||||
is_encrypted=is_encrypted,
|
is_encrypted=is_encrypted,
|
||||||
)
|
)
|
||||||
|
|
||||||
newly_joined_room_ids, newly_left_room_map = (
|
(
|
||||||
await self._get_newly_joined_and_left_rooms(
|
newly_joined_room_ids,
|
||||||
|
newly_left_room_map,
|
||||||
|
) = await self._get_newly_joined_and_left_rooms(
|
||||||
user_id, from_token=from_token, to_token=to_token
|
user_id, from_token=from_token, to_token=to_token
|
||||||
)
|
)
|
||||||
)
|
|
||||||
dm_room_ids = await self._get_dm_rooms_for_user(user_id)
|
dm_room_ids = await self._get_dm_rooms_for_user(user_id)
|
||||||
|
|
||||||
# Handle state resets in the from -> to token range.
|
# Handle state resets in the from -> to token range.
|
||||||
|
@ -958,11 +959,12 @@ class SlidingSyncRoomLists:
|
||||||
else:
|
else:
|
||||||
rooms_for_user[room_id] = change_room_for_user
|
rooms_for_user[room_id] = change_room_for_user
|
||||||
|
|
||||||
newly_joined_room_ids, newly_left_room_ids = (
|
(
|
||||||
await self._get_newly_joined_and_left_rooms(
|
newly_joined_room_ids,
|
||||||
|
newly_left_room_ids,
|
||||||
|
) = await self._get_newly_joined_and_left_rooms(
|
||||||
user_id, to_token=to_token, from_token=from_token
|
user_id, to_token=to_token, from_token=from_token
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
dm_room_ids = await self._get_dm_rooms_for_user(user_id)
|
dm_room_ids = await self._get_dm_rooms_for_user(user_id)
|
||||||
|
|
||||||
|
|
|
@ -183,10 +183,7 @@ class JoinedSyncResult:
|
||||||
to tell if room needs to be part of the sync result.
|
to tell if room needs to be part of the sync result.
|
||||||
"""
|
"""
|
||||||
return bool(
|
return bool(
|
||||||
self.timeline
|
self.timeline or self.state or self.ephemeral or self.account_data
|
||||||
or self.state
|
|
||||||
or self.ephemeral
|
|
||||||
or self.account_data
|
|
||||||
# nb the notification count does not, er, count: if there's nothing
|
# nb the notification count does not, er, count: if there's nothing
|
||||||
# else in the result, we don't need to send it.
|
# else in the result, we don't need to send it.
|
||||||
)
|
)
|
||||||
|
@ -575,11 +572,11 @@ class SyncHandler:
|
||||||
if timeout == 0 or since_token is None or full_state:
|
if timeout == 0 or since_token is None or full_state:
|
||||||
# we are going to return immediately, so don't bother calling
|
# we are going to return immediately, so don't bother calling
|
||||||
# notifier.wait_for_events.
|
# notifier.wait_for_events.
|
||||||
result: Union[SyncResult, E2eeSyncResult] = (
|
result: Union[
|
||||||
await self.current_sync_for_user(
|
SyncResult, E2eeSyncResult
|
||||||
|
] = await self.current_sync_for_user(
|
||||||
sync_config, sync_version, since_token, full_state=full_state
|
sync_config, sync_version, since_token, full_state=full_state
|
||||||
)
|
)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Otherwise, we wait for something to happen and report it to the user.
|
# Otherwise, we wait for something to happen and report it to the user.
|
||||||
async def current_sync_callback(
|
async def current_sync_callback(
|
||||||
|
@ -673,11 +670,11 @@ class SyncHandler:
|
||||||
|
|
||||||
# Go through the `/sync` v2 path
|
# Go through the `/sync` v2 path
|
||||||
if sync_version == SyncVersion.SYNC_V2:
|
if sync_version == SyncVersion.SYNC_V2:
|
||||||
sync_result: Union[SyncResult, E2eeSyncResult] = (
|
sync_result: Union[
|
||||||
await self.generate_sync_result(
|
SyncResult, E2eeSyncResult
|
||||||
|
] = await self.generate_sync_result(
|
||||||
sync_config, since_token, full_state
|
sync_config, since_token, full_state
|
||||||
)
|
)
|
||||||
)
|
|
||||||
# Go through the MSC3575 Sliding Sync `/sync/e2ee` path
|
# Go through the MSC3575 Sliding Sync `/sync/e2ee` path
|
||||||
elif sync_version == SyncVersion.E2EE_SYNC:
|
elif sync_version == SyncVersion.E2EE_SYNC:
|
||||||
sync_result = await self.generate_e2ee_sync_result(
|
sync_result = await self.generate_e2ee_sync_result(
|
||||||
|
@ -1488,14 +1485,17 @@ class SyncHandler:
|
||||||
# timeline here. The caller will then dedupe any redundant
|
# timeline here. The caller will then dedupe any redundant
|
||||||
# ones.
|
# ones.
|
||||||
|
|
||||||
state_ids = await self._state_storage_controller.get_state_ids_for_event(
|
state_ids = (
|
||||||
|
await self._state_storage_controller.get_state_ids_for_event(
|
||||||
batch.events[0].event_id,
|
batch.events[0].event_id,
|
||||||
# we only want members!
|
# we only want members!
|
||||||
state_filter=StateFilter.from_types(
|
state_filter=StateFilter.from_types(
|
||||||
(EventTypes.Member, member) for member in members_to_fetch
|
(EventTypes.Member, member)
|
||||||
|
for member in members_to_fetch
|
||||||
),
|
),
|
||||||
await_full_state=False,
|
await_full_state=False,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
return state_ids
|
return state_ids
|
||||||
|
|
||||||
if batch:
|
if batch:
|
||||||
|
@ -2166,18 +2166,18 @@ class SyncHandler:
|
||||||
|
|
||||||
if push_rules_changed:
|
if push_rules_changed:
|
||||||
global_account_data = dict(global_account_data)
|
global_account_data = dict(global_account_data)
|
||||||
global_account_data[AccountDataTypes.PUSH_RULES] = (
|
global_account_data[
|
||||||
await self._push_rules_handler.push_rules_for_user(sync_config.user)
|
AccountDataTypes.PUSH_RULES
|
||||||
)
|
] = await self._push_rules_handler.push_rules_for_user(sync_config.user)
|
||||||
else:
|
else:
|
||||||
all_global_account_data = await self.store.get_global_account_data_for_user(
|
all_global_account_data = await self.store.get_global_account_data_for_user(
|
||||||
user_id
|
user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
global_account_data = dict(all_global_account_data)
|
global_account_data = dict(all_global_account_data)
|
||||||
global_account_data[AccountDataTypes.PUSH_RULES] = (
|
global_account_data[
|
||||||
await self._push_rules_handler.push_rules_for_user(sync_config.user)
|
AccountDataTypes.PUSH_RULES
|
||||||
)
|
] = await self._push_rules_handler.push_rules_for_user(sync_config.user)
|
||||||
|
|
||||||
account_data_for_user = (
|
account_data_for_user = (
|
||||||
await sync_config.filter_collection.filter_global_account_data(
|
await sync_config.filter_collection.filter_global_account_data(
|
||||||
|
|
|
@ -183,7 +183,7 @@ class WorkerLocksHandler:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _wake_all_locks(
|
def _wake_all_locks(
|
||||||
locks: Collection[Union[WaitingLock, WaitingMultiLock]]
|
locks: Collection[Union[WaitingLock, WaitingMultiLock]],
|
||||||
) -> None:
|
) -> None:
|
||||||
for lock in locks:
|
for lock in locks:
|
||||||
deferred = lock.deferred
|
deferred = lock.deferred
|
||||||
|
|
|
@ -1313,6 +1313,5 @@ def is_unknown_endpoint(
|
||||||
)
|
)
|
||||||
) or (
|
) or (
|
||||||
# Older Synapses returned a 400 error.
|
# Older Synapses returned a 400 error.
|
||||||
e.code == 400
|
e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED
|
||||||
and synapse_error.errcode == Codes.UNRECOGNIZED
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -233,7 +233,7 @@ def return_html_error(
|
||||||
|
|
||||||
|
|
||||||
def wrap_async_request_handler(
|
def wrap_async_request_handler(
|
||||||
h: Callable[["_AsyncResource", "SynapseRequest"], Awaitable[None]]
|
h: Callable[["_AsyncResource", "SynapseRequest"], Awaitable[None]],
|
||||||
) -> Callable[["_AsyncResource", "SynapseRequest"], "defer.Deferred[None]"]:
|
) -> Callable[["_AsyncResource", "SynapseRequest"], "defer.Deferred[None]"]:
|
||||||
"""Wraps an async request handler so that it calls request.processing.
|
"""Wraps an async request handler so that it calls request.processing.
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
"""
|
"""
|
||||||
Log formatters that output terse JSON.
|
Log formatters that output terse JSON.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,7 @@ them.
|
||||||
|
|
||||||
See doc/log_contexts.rst for details on how this works.
|
See doc/log_contexts.rst for details on how this works.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import typing
|
import typing
|
||||||
|
@ -751,7 +752,7 @@ def preserve_fn(
|
||||||
f: Union[
|
f: Union[
|
||||||
Callable[P, R],
|
Callable[P, R],
|
||||||
Callable[P, Awaitable[R]],
|
Callable[P, Awaitable[R]],
|
||||||
]
|
],
|
||||||
) -> Callable[P, "defer.Deferred[R]"]:
|
) -> Callable[P, "defer.Deferred[R]"]:
|
||||||
"""Function decorator which wraps the function with run_in_background"""
|
"""Function decorator which wraps the function with run_in_background"""
|
||||||
|
|
||||||
|
|
|
@ -169,6 +169,7 @@ Gotchas
|
||||||
than one caller? Will all of those calling functions have be in a context
|
than one caller? Will all of those calling functions have be in a context
|
||||||
with an active span?
|
with an active span?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import enum
|
import enum
|
||||||
import inspect
|
import inspect
|
||||||
|
@ -414,7 +415,7 @@ def ensure_active_span(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def ensure_active_span_inner_1(
|
def ensure_active_span_inner_1(
|
||||||
func: Callable[P, R]
|
func: Callable[P, R],
|
||||||
) -> Callable[P, Union[Optional[T], R]]:
|
) -> Callable[P, Union[Optional[T], R]]:
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def ensure_active_span_inner_2(
|
def ensure_active_span_inner_2(
|
||||||
|
@ -700,7 +701,7 @@ def set_operation_name(operation_name: str) -> None:
|
||||||
|
|
||||||
@only_if_tracing
|
@only_if_tracing
|
||||||
def force_tracing(
|
def force_tracing(
|
||||||
span: Union["opentracing.Span", _Sentinel] = _Sentinel.sentinel
|
span: Union["opentracing.Span", _Sentinel] = _Sentinel.sentinel,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Force sampling for the active/given span and its children.
|
"""Force sampling for the active/given span and its children.
|
||||||
|
|
||||||
|
@ -1093,9 +1094,10 @@ def trace_servlet(
|
||||||
|
|
||||||
# Mypy seems to think that start_context.tag below can be Optional[str], but
|
# Mypy seems to think that start_context.tag below can be Optional[str], but
|
||||||
# that doesn't appear to be correct and works in practice.
|
# that doesn't appear to be correct and works in practice.
|
||||||
request_tags[
|
|
||||||
SynapseTags.REQUEST_TAG
|
request_tags[SynapseTags.REQUEST_TAG] = (
|
||||||
] = request.request_metrics.start_context.tag # type: ignore[assignment]
|
request.request_metrics.start_context.tag # type: ignore[assignment]
|
||||||
|
)
|
||||||
|
|
||||||
# set the tags *after* the servlet completes, in case it decided to
|
# set the tags *after* the servlet completes, in case it decided to
|
||||||
# prioritise the span (tags will get dropped on unprioritised spans)
|
# prioritise the span (tags will get dropped on unprioritised spans)
|
||||||
|
|
|
@ -293,7 +293,7 @@ def wrap_as_background_process(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrap_as_background_process_inner(
|
def wrap_as_background_process_inner(
|
||||||
func: Callable[P, Awaitable[Optional[R]]]
|
func: Callable[P, Awaitable[Optional[R]]],
|
||||||
) -> Callable[P, "defer.Deferred[Optional[R]]"]:
|
) -> Callable[P, "defer.Deferred[Optional[R]]"]:
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrap_as_background_process_inner_2(
|
def wrap_as_background_process_inner_2(
|
||||||
|
|
|
@ -304,9 +304,9 @@ class BulkPushRuleEvaluator:
|
||||||
if relation_type == "m.thread" and event.content.get(
|
if relation_type == "m.thread" and event.content.get(
|
||||||
"m.relates_to", {}
|
"m.relates_to", {}
|
||||||
).get("is_falling_back", False):
|
).get("is_falling_back", False):
|
||||||
related_events["m.in_reply_to"][
|
related_events["m.in_reply_to"]["im.vector.is_falling_back"] = (
|
||||||
"im.vector.is_falling_back"
|
""
|
||||||
] = ""
|
)
|
||||||
|
|
||||||
return related_events
|
return related_events
|
||||||
|
|
||||||
|
@ -372,7 +372,8 @@ class BulkPushRuleEvaluator:
|
||||||
gather_results(
|
gather_results(
|
||||||
(
|
(
|
||||||
run_in_background( # type: ignore[call-arg]
|
run_in_background( # type: ignore[call-arg]
|
||||||
self.store.get_number_joined_users_in_room, event.room_id # type: ignore[arg-type]
|
self.store.get_number_joined_users_in_room,
|
||||||
|
event.room_id, # type: ignore[arg-type]
|
||||||
),
|
),
|
||||||
run_in_background(
|
run_in_background(
|
||||||
self._get_power_levels_and_sender_level,
|
self._get_power_levels_and_sender_level,
|
||||||
|
|
|
@ -119,7 +119,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
async def _handle_request(self, request: Request, content: JsonDict) -> Tuple[int, JsonDict]: # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, content: JsonDict
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
with Measure(self.clock, "repl_fed_send_events_parse"):
|
with Measure(self.clock, "repl_fed_send_events_parse"):
|
||||||
room_id = content["room_id"]
|
room_id = content["room_id"]
|
||||||
backfilled = content["backfilled"]
|
backfilled = content["backfilled"]
|
||||||
|
|
|
@ -98,7 +98,9 @@ class ReplicationCopyPusherRestServlet(ReplicationEndpoint):
|
||||||
self._store = hs.get_datastores().main
|
self._store = hs.get_datastores().main
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(user_id: str, old_room_id: str, new_room_id: str) -> JsonDict: # type: ignore[override]
|
async def _serialize_payload( # type: ignore[override]
|
||||||
|
user_id: str, old_room_id: str, new_room_id: str
|
||||||
|
) -> JsonDict:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore[override]
|
async def _handle_request( # type: ignore[override]
|
||||||
|
@ -109,7 +111,6 @@ class ReplicationCopyPusherRestServlet(ReplicationEndpoint):
|
||||||
old_room_id: str,
|
old_room_id: str,
|
||||||
new_room_id: str,
|
new_room_id: str,
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
|
|
||||||
await self._store.copy_push_rules_from_room_to_room_for_user(
|
await self._store.copy_push_rules_from_room_to_room_for_user(
|
||||||
old_room_id, new_room_id, user_id
|
old_room_id, new_room_id, user_id
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,8 +18,8 @@
|
||||||
# [This file includes modifications made by New Vector Limited]
|
# [This file includes modifications made by New Vector Limited]
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
"""A replication client for use by synapse workers.
|
"""A replication client for use by synapse workers."""
|
||||||
"""
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict, Iterable, Optional, Set, Tuple
|
from typing import TYPE_CHECKING, Dict, Iterable, Optional, Set, Tuple
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are
|
The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are
|
||||||
allowed to be sent by which side.
|
allowed to be sent by which side.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional, Tuple, Type, TypeVar
|
from typing import List, Optional, Tuple, Type, TypeVar
|
||||||
|
|
|
@ -857,7 +857,7 @@ UpdateRow = TypeVar("UpdateRow")
|
||||||
|
|
||||||
|
|
||||||
def _batch_updates(
|
def _batch_updates(
|
||||||
updates: Iterable[Tuple[UpdateToken, UpdateRow]]
|
updates: Iterable[Tuple[UpdateToken, UpdateRow]],
|
||||||
) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]:
|
) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]:
|
||||||
"""Collect stream updates with the same token together
|
"""Collect stream updates with the same token together
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@ protocols.
|
||||||
|
|
||||||
An explanation of this protocol is available in docs/tcp_replication.md
|
An explanation of this protocol is available in docs/tcp_replication.md
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import fcntl
|
import fcntl
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
|
|
|
@ -18,8 +18,7 @@
|
||||||
# [This file includes modifications made by New Vector Limited]
|
# [This file includes modifications made by New Vector Limited]
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
"""The server side of the replication stream.
|
"""The server side of the replication stream."""
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
@ -307,7 +306,7 @@ class ReplicationStreamer:
|
||||||
|
|
||||||
|
|
||||||
def _batch_updates(
|
def _batch_updates(
|
||||||
updates: List[Tuple[Token, StreamRow]]
|
updates: List[Tuple[Token, StreamRow]],
|
||||||
) -> List[Tuple[Optional[Token], StreamRow]]:
|
) -> List[Tuple[Optional[Token], StreamRow]]:
|
||||||
"""Takes a list of updates of form [(token, row)] and sets the token to
|
"""Takes a list of updates of form [(token, row)] and sets the token to
|
||||||
None for all rows where the next row has the same token. This is used to
|
None for all rows where the next row has the same token. This is used to
|
||||||
|
|
|
@ -247,7 +247,7 @@ class _StreamFromIdGen(Stream):
|
||||||
|
|
||||||
|
|
||||||
def current_token_without_instance(
|
def current_token_without_instance(
|
||||||
current_token: Callable[[], int]
|
current_token: Callable[[], int],
|
||||||
) -> Callable[[str], int]:
|
) -> Callable[[str], int]:
|
||||||
"""Takes a current token callback function for a single writer stream
|
"""Takes a current token callback function for a single writer stream
|
||||||
that doesn't take an instance name parameter and wraps it in a function that
|
that doesn't take an instance name parameter and wraps it in a function that
|
||||||
|
|
|
@ -181,8 +181,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
|
||||||
|
|
||||||
uses_allowed = body.get("uses_allowed", None)
|
uses_allowed = body.get("uses_allowed", None)
|
||||||
if not (
|
if not (
|
||||||
uses_allowed is None
|
uses_allowed is None or (type(uses_allowed) is int and uses_allowed >= 0) # noqa: E721
|
||||||
or (type(uses_allowed) is int and uses_allowed >= 0) # noqa: E721
|
|
||||||
):
|
):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
|
|
|
@ -19,8 +19,8 @@
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
|
|
||||||
"""This module contains base REST classes for constructing client v1 servlets.
|
"""This module contains base REST classes for constructing client v1 servlets."""
|
||||||
"""
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, cast
|
from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, cast
|
||||||
|
|
|
@ -108,9 +108,9 @@ class AccountDataServlet(RestServlet):
|
||||||
|
|
||||||
# Push rules are stored in a separate table and must be queried separately.
|
# Push rules are stored in a separate table and must be queried separately.
|
||||||
if account_data_type == AccountDataTypes.PUSH_RULES:
|
if account_data_type == AccountDataTypes.PUSH_RULES:
|
||||||
account_data: Optional[JsonMapping] = (
|
account_data: Optional[
|
||||||
await self._push_rules_handler.push_rules_for_user(requester.user)
|
JsonMapping
|
||||||
)
|
] = await self._push_rules_handler.push_rules_for_user(requester.user)
|
||||||
else:
|
else:
|
||||||
account_data = await self.store.get_global_account_data_by_type_for_user(
|
account_data = await self.store.get_global_account_data_by_type_for_user(
|
||||||
user_id, account_data_type
|
user_id, account_data_type
|
||||||
|
|
|
@ -48,9 +48,7 @@ class AccountValidityRenewServlet(RestServlet):
|
||||||
self.account_renewed_template = (
|
self.account_renewed_template = (
|
||||||
hs.config.account_validity.account_validity_account_renewed_template
|
hs.config.account_validity.account_validity_account_renewed_template
|
||||||
)
|
)
|
||||||
self.account_previously_renewed_template = (
|
self.account_previously_renewed_template = hs.config.account_validity.account_validity_account_previously_renewed_template
|
||||||
hs.config.account_validity.account_validity_account_previously_renewed_template
|
|
||||||
)
|
|
||||||
self.invalid_token_template = (
|
self.invalid_token_template = (
|
||||||
hs.config.account_validity.account_validity_invalid_token_template
|
hs.config.account_validity.account_validity_invalid_token_template
|
||||||
)
|
)
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
"""This module contains REST servlets to do with event streaming, /events."""
|
"""This module contains REST servlets to do with event streaming, /events."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
||||||
|
|
||||||
|
|
|
@ -19,8 +19,8 @@
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
|
|
||||||
""" This module contains REST servlets to do with presence: /presence/<paths>
|
"""This module contains REST servlets to do with presence: /presence/<paths>"""
|
||||||
"""
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Tuple
|
from typing import TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
|
|
|
@ -640,14 +640,12 @@ class RegisterRestServlet(RestServlet):
|
||||||
if not password_hash:
|
if not password_hash:
|
||||||
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
|
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
|
||||||
|
|
||||||
desired_username = (
|
desired_username = await (
|
||||||
await (
|
|
||||||
self.password_auth_provider.get_username_for_registration(
|
self.password_auth_provider.get_username_for_registration(
|
||||||
auth_result,
|
auth_result,
|
||||||
params,
|
params,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if desired_username is None:
|
if desired_username is None:
|
||||||
desired_username = params.get("username", None)
|
desired_username = params.get("username", None)
|
||||||
|
@ -696,13 +694,11 @@ class RegisterRestServlet(RestServlet):
|
||||||
session_id
|
session_id
|
||||||
)
|
)
|
||||||
|
|
||||||
display_name = (
|
display_name = await (
|
||||||
await (
|
|
||||||
self.password_auth_provider.get_displayname_for_registration(
|
self.password_auth_provider.get_displayname_for_registration(
|
||||||
auth_result, params
|
auth_result, params
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
registered_user_id = await self.registration_handler.register_user(
|
registered_user_id = await self.registration_handler.register_user(
|
||||||
localpart=desired_username,
|
localpart=desired_username,
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
"""This module contains REST servlets to do with rooms: /rooms/<paths>"""
|
"""This module contains REST servlets to do with rooms: /rooms/<paths>"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
|
@ -1045,9 +1045,9 @@ class SlidingSyncRestServlet(RestServlet):
|
||||||
serialized_rooms[room_id]["initial"] = room_result.initial
|
serialized_rooms[room_id]["initial"] = room_result.initial
|
||||||
|
|
||||||
if room_result.unstable_expanded_timeline:
|
if room_result.unstable_expanded_timeline:
|
||||||
serialized_rooms[room_id][
|
serialized_rooms[room_id]["unstable_expanded_timeline"] = (
|
||||||
"unstable_expanded_timeline"
|
room_result.unstable_expanded_timeline
|
||||||
] = room_result.unstable_expanded_timeline
|
)
|
||||||
|
|
||||||
# This will be omitted for invite/knock rooms with `stripped_state`
|
# This will be omitted for invite/knock rooms with `stripped_state`
|
||||||
if (
|
if (
|
||||||
|
@ -1082,9 +1082,9 @@ class SlidingSyncRestServlet(RestServlet):
|
||||||
|
|
||||||
# This will be omitted for invite/knock rooms with `stripped_state`
|
# This will be omitted for invite/knock rooms with `stripped_state`
|
||||||
if room_result.prev_batch is not None:
|
if room_result.prev_batch is not None:
|
||||||
serialized_rooms[room_id]["prev_batch"] = (
|
serialized_rooms[room_id][
|
||||||
await room_result.prev_batch.to_string(self.store)
|
"prev_batch"
|
||||||
)
|
] = await room_result.prev_batch.to_string(self.store)
|
||||||
|
|
||||||
# This will be omitted for invite/knock rooms with `stripped_state`
|
# This will be omitted for invite/knock rooms with `stripped_state`
|
||||||
if room_result.num_live is not None:
|
if room_result.num_live is not None:
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
|
|
||||||
"""This module contains logic for storing HTTP PUT transactions. This is used
|
"""This module contains logic for storing HTTP PUT transactions. This is used
|
||||||
to ensure idempotency when performing PUTs using the REST API."""
|
to ensure idempotency when performing PUTs using the REST API."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Hashable, Tuple
|
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Hashable, Tuple
|
||||||
|
|
||||||
|
|
|
@ -191,11 +191,11 @@ class RemoteKey(RestServlet):
|
||||||
server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
|
server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
|
||||||
for server_name, key_ids in query.items():
|
for server_name, key_ids in query.items():
|
||||||
if key_ids:
|
if key_ids:
|
||||||
results: Mapping[str, Optional[FetchKeyResultForRemote]] = (
|
results: Mapping[
|
||||||
await self.store.get_server_keys_json_for_remote(
|
str, Optional[FetchKeyResultForRemote]
|
||||||
|
] = await self.store.get_server_keys_json_for_remote(
|
||||||
server_name, key_ids
|
server_name, key_ids
|
||||||
)
|
)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
results = await self.store.get_all_server_keys_json_for_remote(
|
results = await self.store.get_all_server_keys_json_for_remote(
|
||||||
server_name
|
server_name
|
||||||
|
|
|
@ -65,9 +65,9 @@ class WellKnownBuilder:
|
||||||
}
|
}
|
||||||
account_management_url = await auth.account_management_url()
|
account_management_url = await auth.account_management_url()
|
||||||
if account_management_url is not None:
|
if account_management_url is not None:
|
||||||
result["org.matrix.msc2965.authentication"][
|
result["org.matrix.msc2965.authentication"]["account"] = (
|
||||||
"account"
|
account_management_url
|
||||||
] = account_management_url
|
)
|
||||||
|
|
||||||
if self._config.server.extra_well_known_client_content:
|
if self._config.server.extra_well_known_client_content:
|
||||||
for (
|
for (
|
||||||
|
|
|
@ -119,7 +119,9 @@ class ResourceLimitsServerNotices:
|
||||||
elif not currently_blocked and limit_msg:
|
elif not currently_blocked and limit_msg:
|
||||||
# Room is not notifying of a block, when it ought to be.
|
# Room is not notifying of a block, when it ought to be.
|
||||||
await self._apply_limit_block_notification(
|
await self._apply_limit_block_notification(
|
||||||
user_id, limit_msg, limit_type # type: ignore
|
user_id,
|
||||||
|
limit_msg,
|
||||||
|
limit_type, # type: ignore
|
||||||
)
|
)
|
||||||
except SynapseError as e:
|
except SynapseError as e:
|
||||||
logger.error("Error sending resource limits server notice: %s", e)
|
logger.error("Error sending resource limits server notice: %s", e)
|
||||||
|
|
|
@ -416,7 +416,7 @@ class EventsPersistenceStorageController:
|
||||||
set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
|
set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
|
||||||
|
|
||||||
async def enqueue(
|
async def enqueue(
|
||||||
item: Tuple[str, List[Tuple[EventBase, EventContext]]]
|
item: Tuple[str, List[Tuple[EventBase, EventContext]]],
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, str]:
|
||||||
room_id, evs_ctxs = item
|
room_id, evs_ctxs = item
|
||||||
return await self._event_persist_queue.add_to_queue(
|
return await self._event_persist_queue.add_to_queue(
|
||||||
|
@ -792,9 +792,9 @@ class EventsPersistenceStorageController:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove any events which are prev_events of any existing events.
|
# Remove any events which are prev_events of any existing events.
|
||||||
existing_prevs: Collection[str] = (
|
existing_prevs: Collection[
|
||||||
await self.persist_events_store._get_events_which_are_prevs(result)
|
str
|
||||||
)
|
] = await self.persist_events_store._get_events_which_are_prevs(result)
|
||||||
result.difference_update(existing_prevs)
|
result.difference_update(existing_prevs)
|
||||||
|
|
||||||
# Finally handle the case where the new events have soft-failed prev
|
# Finally handle the case where the new events have soft-failed prev
|
||||||
|
|
|
@ -238,9 +238,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
|
||||||
INNER JOIN user_ips USING (user_id, access_token, ip)
|
INNER JOIN user_ips USING (user_id, access_token, ip)
|
||||||
GROUP BY user_id, access_token, ip
|
GROUP BY user_id, access_token, ip
|
||||||
HAVING count(*) > 1
|
HAVING count(*) > 1
|
||||||
""".format(
|
""".format(clause),
|
||||||
clause
|
|
||||||
),
|
|
||||||
args,
|
args,
|
||||||
)
|
)
|
||||||
res = cast(
|
res = cast(
|
||||||
|
@ -373,9 +371,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
) c
|
) c
|
||||||
INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
|
INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
|
||||||
""" % {
|
""" % {"where_clause": where_clause}
|
||||||
"where_clause": where_clause
|
|
||||||
}
|
|
||||||
txn.execute(sql, where_args + [batch_size])
|
txn.execute(sql, where_args + [batch_size])
|
||||||
|
|
||||||
rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
|
rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
|
||||||
|
|
|
@ -1116,7 +1116,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
||||||
|
|
||||||
txn.execute(sql, (start, stop))
|
txn.execute(sql, (start, stop))
|
||||||
|
|
||||||
destinations = {d for d, in txn}
|
destinations = {d for (d,) in txn}
|
||||||
to_remove = set()
|
to_remove = set()
|
||||||
for d in destinations:
|
for d in destinations:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -670,9 +670,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
result["keys"] = keys
|
result["keys"] = keys
|
||||||
|
|
||||||
device_display_name = None
|
device_display_name = None
|
||||||
if (
|
if self.hs.config.federation.allow_device_name_lookup_over_federation:
|
||||||
self.hs.config.federation.allow_device_name_lookup_over_federation
|
|
||||||
):
|
|
||||||
device_display_name = device.display_name
|
device_display_name = device.display_name
|
||||||
if device_display_name:
|
if device_display_name:
|
||||||
result["device_display_name"] = device_display_name
|
result["device_display_name"] = device_display_name
|
||||||
|
@ -917,7 +915,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
from_key,
|
from_key,
|
||||||
to_key,
|
to_key,
|
||||||
)
|
)
|
||||||
return {u for u, in rows}
|
return {u for (u,) in rows}
|
||||||
|
|
||||||
@cancellable
|
@cancellable
|
||||||
async def get_users_whose_devices_changed(
|
async def get_users_whose_devices_changed(
|
||||||
|
@ -968,7 +966,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
txn.database_engine, "user_id", chunk
|
txn.database_engine, "user_id", chunk
|
||||||
)
|
)
|
||||||
txn.execute(sql % (clause,), [from_key, to_key] + args)
|
txn.execute(sql % (clause,), [from_key, to_key] + args)
|
||||||
changes.update(user_id for user_id, in txn)
|
changes.update(user_id for (user_id,) in txn)
|
||||||
|
|
||||||
return changes
|
return changes
|
||||||
|
|
||||||
|
@ -1520,7 +1518,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
args: List[Any],
|
args: List[Any],
|
||||||
) -> Set[str]:
|
) -> Set[str]:
|
||||||
txn.execute(sql.format(clause=clause), args)
|
txn.execute(sql.format(clause=clause), args)
|
||||||
return {user_id for user_id, in txn}
|
return {user_id for (user_id,) in txn}
|
||||||
|
|
||||||
changes = set()
|
changes = set()
|
||||||
for chunk in batch_iter(changed_room_ids, 1000):
|
for chunk in batch_iter(changed_room_ids, 1000):
|
||||||
|
@ -1560,7 +1558,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
) -> Set[str]:
|
) -> Set[str]:
|
||||||
txn.execute(sql, (from_id, to_id))
|
txn.execute(sql, (from_id, to_id))
|
||||||
return {room_id for room_id, in txn}
|
return {room_id for (room_id,) in txn}
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_all_device_list_changes",
|
"get_all_device_list_changes",
|
||||||
|
|
|
@ -387,9 +387,7 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
|
||||||
is_verified, session_data
|
is_verified, session_data
|
||||||
FROM e2e_room_keys
|
FROM e2e_room_keys
|
||||||
WHERE user_id = ? AND version = ? AND (%s)
|
WHERE user_id = ? AND version = ? AND (%s)
|
||||||
""" % (
|
""" % (" OR ".join(where_clauses))
|
||||||
" OR ".join(where_clauses)
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(sql, params)
|
txn.execute(sql, params)
|
||||||
|
|
||||||
|
|
|
@ -472,9 +472,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||||
signature_sql = """
|
signature_sql = """
|
||||||
SELECT user_id, key_id, target_device_id, signature
|
SELECT user_id, key_id, target_device_id, signature
|
||||||
FROM e2e_cross_signing_signatures WHERE %s
|
FROM e2e_cross_signing_signatures WHERE %s
|
||||||
""" % (
|
""" % (" OR ".join("(" + q + ")" for q in signature_query_clauses))
|
||||||
" OR ".join("(" + q + ")" for q in signature_query_clauses)
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(signature_sql, signature_query_params)
|
txn.execute(signature_sql, signature_query_params)
|
||||||
return cast(
|
return cast(
|
||||||
|
@ -917,9 +915,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||||
FROM e2e_cross_signing_keys
|
FROM e2e_cross_signing_keys
|
||||||
WHERE %(clause)s
|
WHERE %(clause)s
|
||||||
ORDER BY user_id, keytype, stream_id DESC
|
ORDER BY user_id, keytype, stream_id DESC
|
||||||
""" % {
|
""" % {"clause": clause}
|
||||||
"clause": clause
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
# SQLite has special handling for bare columns when using
|
# SQLite has special handling for bare columns when using
|
||||||
# MIN/MAX with a `GROUP BY` clause where it picks the value from
|
# MIN/MAX with a `GROUP BY` clause where it picks the value from
|
||||||
|
@ -929,9 +925,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||||
FROM e2e_cross_signing_keys
|
FROM e2e_cross_signing_keys
|
||||||
WHERE %(clause)s
|
WHERE %(clause)s
|
||||||
GROUP BY user_id, keytype
|
GROUP BY user_id, keytype
|
||||||
""" % {
|
""" % {"clause": clause}
|
||||||
"clause": clause
|
|
||||||
}
|
|
||||||
|
|
||||||
txn.execute(sql, params)
|
txn.execute(sql, params)
|
||||||
|
|
||||||
|
|
|
@ -326,7 +326,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows = txn.execute_values(sql, chains.items())
|
rows = txn.execute_values(sql, chains.items())
|
||||||
results.update(r for r, in rows)
|
results.update(r for (r,) in rows)
|
||||||
else:
|
else:
|
||||||
# For SQLite we just fall back to doing a noddy for loop.
|
# For SQLite we just fall back to doing a noddy for loop.
|
||||||
sql = """
|
sql = """
|
||||||
|
@ -335,7 +335,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
"""
|
"""
|
||||||
for chain_id, max_no in chains.items():
|
for chain_id, max_no in chains.items():
|
||||||
txn.execute(sql, (chain_id, max_no))
|
txn.execute(sql, (chain_id, max_no))
|
||||||
results.update(r for r, in txn)
|
results.update(r for (r,) in txn)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -645,7 +645,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
]
|
]
|
||||||
|
|
||||||
rows = txn.execute_values(sql, args)
|
rows = txn.execute_values(sql, args)
|
||||||
result.update(r for r, in rows)
|
result.update(r for (r,) in rows)
|
||||||
else:
|
else:
|
||||||
# For SQLite we just fall back to doing a noddy for loop.
|
# For SQLite we just fall back to doing a noddy for loop.
|
||||||
sql = """
|
sql = """
|
||||||
|
@ -654,7 +654,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
"""
|
"""
|
||||||
for chain_id, (min_no, max_no) in chain_to_gap.items():
|
for chain_id, (min_no, max_no) in chain_to_gap.items():
|
||||||
txn.execute(sql, (chain_id, min_no, max_no))
|
txn.execute(sql, (chain_id, min_no, max_no))
|
||||||
result.update(r for r, in txn)
|
result.update(r for (r,) in txn)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -1220,13 +1220,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
HAVING count(*) > ?
|
HAVING count(*) > ?
|
||||||
ORDER BY count(*) DESC
|
ORDER BY count(*) DESC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
""" % (
|
""" % (where_clause,)
|
||||||
where_clause,
|
|
||||||
)
|
|
||||||
|
|
||||||
query_args = list(itertools.chain(room_id_filter, [min_count, limit]))
|
query_args = list(itertools.chain(room_id_filter, [min_count, limit]))
|
||||||
txn.execute(sql, query_args)
|
txn.execute(sql, query_args)
|
||||||
return [room_id for room_id, in txn]
|
return [room_id for (room_id,) in txn]
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
|
"get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
|
||||||
|
@ -1358,7 +1356,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
|
|
||||||
def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]:
|
def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]:
|
||||||
txn.execute(sql, (stream_ordering, room_id))
|
txn.execute(sql, (stream_ordering, room_id))
|
||||||
return [event_id for event_id, in txn]
|
return [event_id for (event_id,) in txn]
|
||||||
|
|
||||||
event_ids = await self.db_pool.runInteraction(
|
event_ids = await self.db_pool.runInteraction(
|
||||||
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
|
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
|
||||||
|
|
|
@ -1860,9 +1860,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
||||||
AND epa.notif = 1
|
AND epa.notif = 1
|
||||||
ORDER BY epa.stream_ordering DESC
|
ORDER BY epa.stream_ordering DESC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
""" % (
|
""" % (before_clause,)
|
||||||
before_clause,
|
|
||||||
)
|
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
return cast(
|
return cast(
|
||||||
List[Tuple[str, str, int, int, str, bool, str, int]], txn.fetchall()
|
List[Tuple[str, str, int, int, str, bool, str, int]], txn.fetchall()
|
||||||
|
|
|
@ -429,9 +429,7 @@ class PersistEventsStore:
|
||||||
if event_type == EventTypes.Member and self.is_mine_id(state_key)
|
if event_type == EventTypes.Member and self.is_mine_id(state_key)
|
||||||
]
|
]
|
||||||
|
|
||||||
membership_snapshot_shared_insert_values: (
|
membership_snapshot_shared_insert_values: SlidingSyncMembershipSnapshotSharedInsertValues = {}
|
||||||
SlidingSyncMembershipSnapshotSharedInsertValues
|
|
||||||
) = {}
|
|
||||||
membership_infos_to_insert_membership_snapshots: List[
|
membership_infos_to_insert_membership_snapshots: List[
|
||||||
SlidingSyncMembershipInfo
|
SlidingSyncMembershipInfo
|
||||||
] = []
|
] = []
|
||||||
|
@ -719,7 +717,7 @@ class PersistEventsStore:
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
retcols=("event_id",),
|
retcols=("event_id",),
|
||||||
)
|
)
|
||||||
already_persisted_events = {event_id for event_id, in rows}
|
already_persisted_events = {event_id for (event_id,) in rows}
|
||||||
state_events = [
|
state_events = [
|
||||||
event
|
event
|
||||||
for event in state_events
|
for event in state_events
|
||||||
|
@ -1830,12 +1828,8 @@ class PersistEventsStore:
|
||||||
if sliding_sync_table_changes.to_insert_membership_snapshots:
|
if sliding_sync_table_changes.to_insert_membership_snapshots:
|
||||||
# Update the `sliding_sync_membership_snapshots` table
|
# Update the `sliding_sync_membership_snapshots` table
|
||||||
#
|
#
|
||||||
sliding_sync_snapshot_keys = (
|
sliding_sync_snapshot_keys = sliding_sync_table_changes.membership_snapshot_shared_insert_values.keys()
|
||||||
sliding_sync_table_changes.membership_snapshot_shared_insert_values.keys()
|
sliding_sync_snapshot_values = sliding_sync_table_changes.membership_snapshot_shared_insert_values.values()
|
||||||
)
|
|
||||||
sliding_sync_snapshot_values = (
|
|
||||||
sliding_sync_table_changes.membership_snapshot_shared_insert_values.values()
|
|
||||||
)
|
|
||||||
# We need to insert/update regardless of whether we have
|
# We need to insert/update regardless of whether we have
|
||||||
# `sliding_sync_snapshot_keys` because there are other fields in the `ON
|
# `sliding_sync_snapshot_keys` because there are other fields in the `ON
|
||||||
# CONFLICT` upsert to run (see inherit case (explained in
|
# CONFLICT` upsert to run (see inherit case (explained in
|
||||||
|
@ -3361,7 +3355,7 @@ class PersistEventsStore:
|
||||||
)
|
)
|
||||||
|
|
||||||
potential_backwards_extremities.difference_update(
|
potential_backwards_extremities.difference_update(
|
||||||
e for e, in existing_events_outliers
|
e for (e,) in existing_events_outliers
|
||||||
)
|
)
|
||||||
|
|
||||||
if potential_backwards_extremities:
|
if potential_backwards_extremities:
|
||||||
|
|
|
@ -647,7 +647,8 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
|
||||||
room_ids = {row[0] for row in rows}
|
room_ids = {row[0] for row in rows}
|
||||||
for room_id in room_ids:
|
for room_id in room_ids:
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined]
|
self.get_latest_event_ids_in_room.invalidate, # type: ignore[attr-defined]
|
||||||
|
(room_id,),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.db_pool.simple_delete_many_txn(
|
self.db_pool.simple_delete_many_txn(
|
||||||
|
@ -2065,9 +2066,7 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
|
||||||
)
|
)
|
||||||
|
|
||||||
# Map of values to insert/update in the `sliding_sync_membership_snapshots` table
|
# Map of values to insert/update in the `sliding_sync_membership_snapshots` table
|
||||||
sliding_sync_membership_snapshots_insert_map: (
|
sliding_sync_membership_snapshots_insert_map: SlidingSyncMembershipSnapshotSharedInsertValues = {}
|
||||||
SlidingSyncMembershipSnapshotSharedInsertValues
|
|
||||||
) = {}
|
|
||||||
if membership == Membership.JOIN:
|
if membership == Membership.JOIN:
|
||||||
# If we're still joined, we can pull from current state.
|
# If we're still joined, we can pull from current state.
|
||||||
current_state_ids_map: StateMap[
|
current_state_ids_map: StateMap[
|
||||||
|
@ -2149,15 +2148,16 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
|
||||||
# membership (i.e. the room shouldn't disappear if your using the
|
# membership (i.e. the room shouldn't disappear if your using the
|
||||||
# `is_encrypted` filter and you leave).
|
# `is_encrypted` filter and you leave).
|
||||||
if membership in (Membership.LEAVE, Membership.BAN) and is_outlier:
|
if membership in (Membership.LEAVE, Membership.BAN) and is_outlier:
|
||||||
invite_or_knock_event_id, invite_or_knock_membership = (
|
(
|
||||||
await self.db_pool.runInteraction(
|
invite_or_knock_event_id,
|
||||||
|
invite_or_knock_membership,
|
||||||
|
) = await self.db_pool.runInteraction(
|
||||||
"sliding_sync_membership_snapshots_bg_update._find_previous_membership",
|
"sliding_sync_membership_snapshots_bg_update._find_previous_membership",
|
||||||
_find_previous_membership_txn,
|
_find_previous_membership_txn,
|
||||||
room_id,
|
room_id,
|
||||||
user_id,
|
user_id,
|
||||||
membership_event_id,
|
membership_event_id,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Pull from the stripped state on the invite/knock event
|
# Pull from the stripped state on the invite/knock event
|
||||||
invite_or_knock_event = await self.get_event(invite_or_knock_event_id)
|
invite_or_knock_event = await self.get_event(invite_or_knock_event_id)
|
||||||
|
@ -2484,9 +2484,7 @@ def _resolve_stale_data_in_sliding_sync_joined_rooms_table(
|
||||||
"progress_json": "{}",
|
"progress_json": "{}",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
depends_on = (
|
depends_on = _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE
|
||||||
_BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE
|
|
||||||
)
|
|
||||||
|
|
||||||
# Now kick-off the background update to catch-up with what we missed while Synapse
|
# Now kick-off the background update to catch-up with what we missed while Synapse
|
||||||
# was downgraded.
|
# was downgraded.
|
||||||
|
|
|
@ -1665,7 +1665,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
txn.database_engine, "e.event_id", event_ids
|
txn.database_engine, "e.event_id", event_ids
|
||||||
)
|
)
|
||||||
txn.execute(sql + clause, args)
|
txn.execute(sql + clause, args)
|
||||||
found_events = {eid for eid, in txn}
|
found_events = {eid for (eid,) in txn}
|
||||||
|
|
||||||
# ... and then we can update the results for each key
|
# ... and then we can update the results for each key
|
||||||
return {eid: (eid in found_events) for eid in event_ids}
|
return {eid: (eid in found_events) for eid in event_ids}
|
||||||
|
@ -1864,9 +1864,9 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
" LIMIT ?"
|
" LIMIT ?"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (-last_id, -current_id, instance_name, limit))
|
txn.execute(sql, (-last_id, -current_id, instance_name, limit))
|
||||||
new_event_updates: List[Tuple[int, Tuple[str, str, str, str, str, str]]] = (
|
new_event_updates: List[
|
||||||
[]
|
Tuple[int, Tuple[str, str, str, str, str, str]]
|
||||||
)
|
] = []
|
||||||
row: Tuple[int, str, str, str, str, str, str]
|
row: Tuple[int, str, str, str, str, str, str]
|
||||||
# Type safety: iterating over `txn` yields `Tuple`, i.e.
|
# Type safety: iterating over `txn` yields `Tuple`, i.e.
|
||||||
# `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
|
# `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
|
||||||
|
|
|
@ -201,7 +201,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
||||||
txn.execute_batch(
|
txn.execute_batch(
|
||||||
"INSERT INTO event_backward_extremities (room_id, event_id)"
|
"INSERT INTO event_backward_extremities (room_id, event_id)"
|
||||||
" VALUES (?, ?)",
|
" VALUES (?, ?)",
|
||||||
[(room_id, event_id) for event_id, in new_backwards_extrems],
|
[(room_id, event_id) for (event_id,) in new_backwards_extrems],
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("[purge] finding state groups referenced by deleted events")
|
logger.info("[purge] finding state groups referenced by deleted events")
|
||||||
|
@ -215,7 +215,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
referenced_state_groups = {sg for sg, in txn}
|
referenced_state_groups = {sg for (sg,) in txn}
|
||||||
logger.info(
|
logger.info(
|
||||||
"[purge] found %i referenced state groups", len(referenced_state_groups)
|
"[purge] found %i referenced state groups", len(referenced_state_groups)
|
||||||
)
|
)
|
||||||
|
|
|
@ -762,7 +762,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
|
|
||||||
return [room_id for room_id, in txn]
|
return [room_id for (room_id,) in txn]
|
||||||
|
|
||||||
results: List[str] = []
|
results: List[str] = []
|
||||||
for batch in batch_iter(room_ids, 1000):
|
for batch in batch_iter(room_ids, 1000):
|
||||||
|
@ -1030,9 +1030,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||||
SELECT event_id WHERE room_id = ? AND stream_ordering IN (
|
SELECT event_id WHERE room_id = ? AND stream_ordering IN (
|
||||||
SELECT max(stream_ordering) WHERE %s
|
SELECT max(stream_ordering) WHERE %s
|
||||||
)
|
)
|
||||||
""" % (
|
""" % (clause,)
|
||||||
clause,
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(sql, [room_id] + list(args))
|
txn.execute(sql, [room_id] + list(args))
|
||||||
rows = txn.fetchall()
|
rows = txn.fetchall()
|
||||||
|
|
|
@ -1250,9 +1250,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
SELECT address, session_id, medium, client_secret,
|
SELECT address, session_id, medium, client_secret,
|
||||||
last_send_attempt, validated_at
|
last_send_attempt, validated_at
|
||||||
FROM threepid_validation_session WHERE %s
|
FROM threepid_validation_session WHERE %s
|
||||||
""" % (
|
""" % (" AND ".join("%s = ?" % k for k in keyvalues.keys()),)
|
||||||
" AND ".join("%s = ?" % k for k in keyvalues.keys()),
|
|
||||||
)
|
|
||||||
|
|
||||||
if validated is not None:
|
if validated is not None:
|
||||||
sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
|
sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
|
||||||
|
|
|
@ -1608,9 +1608,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||||
FROM event_reports AS er
|
FROM event_reports AS er
|
||||||
JOIN room_stats_state ON room_stats_state.room_id = er.room_id
|
JOIN room_stats_state ON room_stats_state.room_id = er.room_id
|
||||||
{}
|
{}
|
||||||
""".format(
|
""".format(where_clause)
|
||||||
where_clause
|
|
||||||
)
|
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
count = cast(Tuple[int], txn.fetchone())[0]
|
count = cast(Tuple[int], txn.fetchone())[0]
|
||||||
|
|
||||||
|
|
|
@ -232,9 +232,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
AND m.room_id = c.room_id
|
AND m.room_id = c.room_id
|
||||||
AND m.user_id = c.state_key
|
AND m.user_id = c.state_key
|
||||||
WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? AND %s
|
WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? AND %s
|
||||||
""" % (
|
""" % (clause,)
|
||||||
clause,
|
|
||||||
)
|
|
||||||
txn.execute(sql, (room_id, Membership.JOIN, *ids))
|
txn.execute(sql, (room_id, Membership.JOIN, *ids))
|
||||||
|
|
||||||
return {r[0]: ProfileInfo(display_name=r[1], avatar_url=r[2]) for r in txn}
|
return {r[0]: ProfileInfo(display_name=r[1], avatar_url=r[2]) for r in txn}
|
||||||
|
@ -531,9 +529,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
WHERE
|
WHERE
|
||||||
user_id = ?
|
user_id = ?
|
||||||
AND %s
|
AND %s
|
||||||
""" % (
|
""" % (clause,)
|
||||||
clause,
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(sql, (user_id, *args))
|
txn.execute(sql, (user_id, *args))
|
||||||
results = [
|
results = [
|
||||||
|
@ -813,7 +809,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
txn.execute(sql, (user_id, *args))
|
txn.execute(sql, (user_id, *args))
|
||||||
return {u: True for u, in txn}
|
return {u: True for (u,) in txn}
|
||||||
|
|
||||||
to_return = {}
|
to_return = {}
|
||||||
for batch_user_ids in batch_iter(other_user_ids, 1000):
|
for batch_user_ids in batch_iter(other_user_ids, 1000):
|
||||||
|
@ -1031,7 +1027,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
AND room_id = ?
|
AND room_id = ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (room_id,))
|
txn.execute(sql, (room_id,))
|
||||||
return {d for d, in txn}
|
return {d for (d,) in txn}
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_current_hosts_in_room", get_current_hosts_in_room_txn
|
"get_current_hosts_in_room", get_current_hosts_in_room_txn
|
||||||
|
@ -1099,7 +1095,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (room_id,))
|
txn.execute(sql, (room_id,))
|
||||||
# `server_domain` will be `NULL` for malformed MXIDs with no colons.
|
# `server_domain` will be `NULL` for malformed MXIDs with no colons.
|
||||||
return tuple(d for d, in txn if d is not None)
|
return tuple(d for (d,) in txn if d is not None)
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
|
"get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
|
||||||
|
@ -1316,9 +1312,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||||
room_id = ? AND membership = ?
|
room_id = ? AND membership = ?
|
||||||
AND NOT (%s)
|
AND NOT (%s)
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
""" % (
|
""" % (clause,)
|
||||||
clause,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _is_local_host_in_room_ignoring_users_txn(
|
def _is_local_host_in_room_ignoring_users_txn(
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
|
@ -1464,10 +1458,12 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
||||||
self, progress: JsonDict, batch_size: int
|
self, progress: JsonDict, batch_size: int
|
||||||
) -> int:
|
) -> int:
|
||||||
target_min_stream_id = progress.get(
|
target_min_stream_id = progress.get(
|
||||||
"target_min_stream_id_inclusive", self._min_stream_order_on_start # type: ignore[attr-defined]
|
"target_min_stream_id_inclusive",
|
||||||
|
self._min_stream_order_on_start, # type: ignore[attr-defined]
|
||||||
)
|
)
|
||||||
max_stream_id = progress.get(
|
max_stream_id = progress.get(
|
||||||
"max_stream_id_exclusive", self._stream_order_on_start + 1 # type: ignore[attr-defined]
|
"max_stream_id_exclusive",
|
||||||
|
self._stream_order_on_start + 1, # type: ignore[attr-defined]
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_membership_profile_txn(txn: LoggingTransaction) -> int:
|
def add_membership_profile_txn(txn: LoggingTransaction) -> int:
|
||||||
|
|
|
@ -177,9 +177,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||||
AND (%s)
|
AND (%s)
|
||||||
ORDER BY stream_ordering DESC
|
ORDER BY stream_ordering DESC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
""" % (
|
""" % (" OR ".join("type = '%s'" % (t,) for t in TYPES),)
|
||||||
" OR ".join("type = '%s'" % (t,) for t in TYPES),
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
|
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
|
||||||
|
|
||||||
|
|
|
@ -535,7 +535,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
desc="check_if_events_in_current_state",
|
desc="check_if_events_in_current_state",
|
||||||
)
|
)
|
||||||
|
|
||||||
return frozenset(event_id for event_id, in rows)
|
return frozenset(event_id for (event_id,) in rows)
|
||||||
|
|
||||||
# FIXME: how should this be cached?
|
# FIXME: how should this be cached?
|
||||||
@cancellable
|
@cancellable
|
||||||
|
|
|
@ -161,7 +161,7 @@ class StatsStore(StateDeltasStore):
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (last_user_id, batch_size))
|
txn.execute(sql, (last_user_id, batch_size))
|
||||||
return [r for r, in txn]
|
return [r for (r,) in txn]
|
||||||
|
|
||||||
users_to_work_on = await self.db_pool.runInteraction(
|
users_to_work_on = await self.db_pool.runInteraction(
|
||||||
"_populate_stats_process_users", _get_next_batch
|
"_populate_stats_process_users", _get_next_batch
|
||||||
|
@ -207,7 +207,7 @@ class StatsStore(StateDeltasStore):
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (last_room_id, batch_size))
|
txn.execute(sql, (last_room_id, batch_size))
|
||||||
return [r for r, in txn]
|
return [r for (r,) in txn]
|
||||||
|
|
||||||
rooms_to_work_on = await self.db_pool.runInteraction(
|
rooms_to_work_on = await self.db_pool.runInteraction(
|
||||||
"populate_stats_rooms_get_batch", _get_next_batch
|
"populate_stats_rooms_get_batch", _get_next_batch
|
||||||
|
@ -751,9 +751,7 @@ class StatsStore(StateDeltasStore):
|
||||||
LEFT JOIN profiles AS p ON lmr.user_id = p.full_user_id
|
LEFT JOIN profiles AS p ON lmr.user_id = p.full_user_id
|
||||||
{}
|
{}
|
||||||
GROUP BY lmr.user_id, displayname
|
GROUP BY lmr.user_id, displayname
|
||||||
""".format(
|
""".format(where_clause)
|
||||||
where_clause
|
|
||||||
)
|
|
||||||
|
|
||||||
# SQLite does not support SELECT COUNT(*) OVER()
|
# SQLite does not support SELECT COUNT(*) OVER()
|
||||||
sql = """
|
sql = """
|
||||||
|
|
|
@ -1122,9 +1122,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
AND e.stream_ordering > ? AND e.stream_ordering <= ?
|
AND e.stream_ordering > ? AND e.stream_ordering <= ?
|
||||||
%s
|
%s
|
||||||
ORDER BY e.stream_ordering ASC
|
ORDER BY e.stream_ordering ASC
|
||||||
""" % (
|
""" % (ignore_room_clause,)
|
||||||
ignore_room_clause,
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
|
|
||||||
|
|
|
@ -224,9 +224,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||||
SELECT room_id, events FROM %s
|
SELECT room_id, events FROM %s
|
||||||
ORDER BY events DESC
|
ORDER BY events DESC
|
||||||
LIMIT 250
|
LIMIT 250
|
||||||
""" % (
|
""" % (TEMP_TABLE + "_rooms",)
|
||||||
TEMP_TABLE + "_rooms",
|
|
||||||
)
|
|
||||||
txn.execute(sql)
|
txn.execute(sql)
|
||||||
rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall())
|
rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall())
|
||||||
|
|
||||||
|
|
|
@ -767,7 +767,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
|
|
||||||
remaining_state_groups = {
|
remaining_state_groups = {
|
||||||
state_group
|
state_group
|
||||||
for state_group, in rows
|
for (state_group,) in rows
|
||||||
if state_group not in state_groups_to_delete
|
if state_group not in state_groups_to_delete
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -607,7 +607,7 @@ def _apply_module_schema_files(
|
||||||
"SELECT file FROM applied_module_schemas WHERE module_name = ?",
|
"SELECT file FROM applied_module_schemas WHERE module_name = ?",
|
||||||
(modname,),
|
(modname,),
|
||||||
)
|
)
|
||||||
applied_deltas = {d for d, in cur}
|
applied_deltas = {d for (d,) in cur}
|
||||||
for name, stream in names_and_streams:
|
for name, stream in names_and_streams:
|
||||||
if name in applied_deltas:
|
if name in applied_deltas:
|
||||||
continue
|
continue
|
||||||
|
@ -710,7 +710,7 @@ def _get_or_create_schema_state(
|
||||||
"SELECT file FROM applied_schema_deltas WHERE version >= ?",
|
"SELECT file FROM applied_schema_deltas WHERE version >= ?",
|
||||||
(current_version,),
|
(current_version,),
|
||||||
)
|
)
|
||||||
applied_deltas = tuple(d for d, in txn)
|
applied_deltas = tuple(d for (d,) in txn)
|
||||||
|
|
||||||
return _SchemaState(
|
return _SchemaState(
|
||||||
current_version=current_version,
|
current_version=current_version,
|
||||||
|
|
|
@ -41,8 +41,6 @@ def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) ->
|
||||||
(user_id, filter_id);
|
(user_id, filter_id);
|
||||||
DROP TABLE user_filters;
|
DROP TABLE user_filters;
|
||||||
ALTER TABLE user_filters_migration RENAME TO user_filters;
|
ALTER TABLE user_filters_migration RENAME TO user_filters;
|
||||||
""" % (
|
""" % (select_clause,)
|
||||||
select_clause,
|
|
||||||
)
|
|
||||||
|
|
||||||
execute_statements_from_stream(cur, StringIO(sql))
|
execute_statements_from_stream(cur, StringIO(sql))
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
This migration handles the process of changing the type of `room_depth.min_depth` to
|
This migration handles the process of changing the type of `room_depth.min_depth` to
|
||||||
a BIGINT.
|
a BIGINT.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from synapse.storage.database import LoggingTransaction
|
from synapse.storage.database import LoggingTransaction
|
||||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,7 @@ This migration adds triggers to the partial_state_events tables to enforce uniqu
|
||||||
|
|
||||||
Triggers cannot be expressed in .sql files, so we have to use a separate file.
|
Triggers cannot be expressed in .sql files, so we have to use a separate file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from synapse.storage.database import LoggingTransaction
|
from synapse.storage.database import LoggingTransaction
|
||||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ for its completion can be removed.
|
||||||
|
|
||||||
Note the background job must still remain defined in the database class.
|
Note the background job must still remain defined in the database class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.storage.database import LoggingTransaction
|
from synapse.storage.database import LoggingTransaction
|
||||||
from synapse.storage.engines import BaseDatabaseEngine
|
from synapse.storage.engines import BaseDatabaseEngine
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
This migration adds triggers to the room membership tables to enforce consistency.
|
This migration adds triggers to the room membership tables to enforce consistency.
|
||||||
Triggers cannot be expressed in .sql files, so we have to use a separate file.
|
Triggers cannot be expressed in .sql files, so we have to use a separate file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from synapse.storage.database import LoggingTransaction
|
from synapse.storage.database import LoggingTransaction
|
||||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
"""
|
"""
|
||||||
This migration adds foreign key constraint to `event_forward_extremities` table.
|
This migration adds foreign key constraint to `event_forward_extremities` table.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from synapse.storage.background_updates import (
|
from synapse.storage.background_updates import (
|
||||||
ForeignKeyConstraint,
|
ForeignKeyConstraint,
|
||||||
run_validate_constraint_and_delete_rows_schema_delta,
|
run_validate_constraint_and_delete_rows_schema_delta,
|
||||||
|
|
|
@ -1308,7 +1308,7 @@ class DeviceListUpdates:
|
||||||
|
|
||||||
|
|
||||||
def get_verify_key_from_cross_signing_key(
|
def get_verify_key_from_cross_signing_key(
|
||||||
key_info: Mapping[str, Any]
|
key_info: Mapping[str, Any],
|
||||||
) -> Tuple[str, VerifyKey]:
|
) -> Tuple[str, VerifyKey]:
|
||||||
"""Get the key ID and signedjson verify key from a cross-signing key dict
|
"""Get the key ID and signedjson verify key from a cross-signing key dict
|
||||||
|
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue