Merge branch 'develop' into rav/module_api_extensions

This commit is contained in:
Richard van der Hoff 2020-01-15 16:00:24 +00:00
commit 107f256cd8
64 changed files with 1281 additions and 295 deletions

View file

@ -1,6 +1,9 @@
Synapse 1.8.0 (2020-01-09)
==========================
**WARNING**: As of this release Synapse will refuse to start if the `log_file` config option is specified. Support for the option was removed in v1.3.0.
Bugfixes
--------

View file

@ -133,6 +133,11 @@ sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
sudo yum groupinstall "Development Tools"
```
Note that Synapse does not support versions of SQLite before 3.11, and CentOS 7
uses SQLite 3.7. You may be able to work around this by installing a more
recent SQLite version, but it is recommended that you instead use a Postgres
database: see [docs/postgres.md](docs/postgres.md).
#### macOS
Installing prerequisites on macOS:

View file

@ -75,6 +75,15 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
Upgrading to v1.8.0
===================
Specifying a ``log_file`` config option will now cause Synapse to refuse to
start, and should be replaced by with the ``log_config`` option. Support for
the ``log_file`` option was removed in v1.3.0 and has since had no effect.
Upgrading to v1.7.0
===================

1
changelog.d/6655.misc Normal file
View file

@ -0,0 +1 @@
Add `local_current_membership` table for tracking local user membership state in rooms.

1
changelog.d/6667.misc Normal file
View file

@ -0,0 +1 @@
Fixup `synapse.replication` to pass mypy checks.

1
changelog.d/6675.removal Normal file
View file

@ -0,0 +1 @@
Synapse no longer supports versions of SQLite before 3.11, and will refuse to start when configured to use an older version. Administrators are recommended to migrate their database to Postgres (see instructions [here](docs/postgres.md)).

1
changelog.d/6681.feature Normal file
View file

@ -0,0 +1 @@
Add new quarantine media admin APIs to quarantine by media ID or by user who uploaded the media.

2
changelog.d/6682.bugfix Normal file
View file

@ -0,0 +1,2 @@
Fix "CRITICAL" errors being logged when a request is received for a uri containing non-ascii characters.

1
changelog.d/6686.misc Normal file
View file

@ -0,0 +1 @@
Allow additional_resources to implement IResource directly.

1
changelog.d/6687.misc Normal file
View file

@ -0,0 +1 @@
Allow REST endpoint implementations to raise a RedirectException, which will redirect the user's browser to a given location.

1
changelog.d/6689.misc Normal file
View file

@ -0,0 +1 @@
Updates to the SAML mapping provider API.

1
changelog.d/6690.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug where we would assign a numeric userid if somebody tried registering with an empty username.

1
changelog.d/6691.misc Normal file
View file

@ -0,0 +1 @@
Remove redundant RegistrationError class.

1
changelog.d/6697.misc Normal file
View file

@ -0,0 +1 @@
Don't block processing of incoming EDUs behind processing PDUs in the same transaction.

1
changelog.d/6698.doc Normal file
View file

@ -0,0 +1 @@
Add more endpoints to the documentation for Synapse workers.

View file

@ -22,19 +22,81 @@ It returns a JSON body like the following:
}
```
# Quarantine media in a room
This API 'quarantines' all the media in a room.
The API is:
```
POST /_synapse/admin/v1/quarantine_media/<room_id>
{}
```
# Quarantine media
Quarantining media means that it is marked as inaccessible by users. It applies
to any local media, and any locally-cached copies of remote media.
The media file itself (and any thumbnails) is not deleted from the server.
## Quarantining media by ID
This API quarantines a single piece of local or remote media.
Request:
```
POST /_synapse/admin/v1/media/quarantine/<server_name>/<media_id>
{}
```
Where `server_name` is in the form of `example.org`, and `media_id` is in the
form of `abcdefg12345...`.
Response:
```
{}
```
## Quarantining media in a room
This API quarantines all local and remote media in a room.
Request:
```
POST /_synapse/admin/v1/room/<room_id>/media/quarantine
{}
```
Where `room_id` is in the form of `!roomid12345:example.org`.
Response:
```
{
"num_quarantined": 10 # The number of media items successfully quarantined
}
```
Note that there is a legacy endpoint, `POST
/_synapse/admin/v1/quarantine_media/<room_id >`, that operates the same.
However, it is deprecated and may be removed in a future release.
## Quarantining all media of a user
This API quarantines all *local* media that a *local* user has uploaded. That is to say, if
you would like to quarantine media uploaded by a user on a remote homeserver, you should
instead use one of the other APIs.
Request:
```
POST /_synapse/admin/v1/user/<user_id>/media/quarantine
{}
```
Where `user_id` is in the form of `@bob:example.org`.
Response:
```
{
"num_quarantined": 10 # The number of media items successfully quarantined
}
```

View file

@ -168,8 +168,11 @@ endpoints matching the following regular expressions:
^/_matrix/federation/v1/make_join/
^/_matrix/federation/v1/make_leave/
^/_matrix/federation/v1/send_join/
^/_matrix/federation/v2/send_join/
^/_matrix/federation/v1/send_leave/
^/_matrix/federation/v2/send_leave/
^/_matrix/federation/v1/invite/
^/_matrix/federation/v2/invite/
^/_matrix/federation/v1/query_auth/
^/_matrix/federation/v1/event_auth/
^/_matrix/federation/v1/exchange_third_party_invite/
@ -199,7 +202,9 @@ Handles the media repository. It can handle all endpoints starting with:
... and the following regular expressions matching media-specific administration APIs:
^/_synapse/admin/v1/purge_media_cache$
^/_synapse/admin/v1/room/.*/media$
^/_synapse/admin/v1/room/.*/media.*$
^/_synapse/admin/v1/user/.*/media.*$
^/_synapse/admin/v1/media/.*$
^/_synapse/admin/v1/quarantine_media/.*$
You should also set `enable_media_repo: False` in the shared configuration
@ -288,6 +293,7 @@ file. For example:
Handles some event creation. It can handle REST endpoints matching:
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state/
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$
^/_matrix/client/(api/v1|r0|unstable)/join/
^/_matrix/client/(api/v1|r0|unstable)/profile/

View file

@ -447,20 +447,15 @@ class Porter(object):
else:
return
def setup_db(self, db_config: DatabaseConnectionConfig, engine):
db_conn = make_conn(db_config, engine)
prepare_database(db_conn, engine, config=None)
db_conn.commit()
return db_conn
@defer.inlineCallbacks
def build_db_store(self, db_config: DatabaseConnectionConfig):
def build_db_store(
self, db_config: DatabaseConnectionConfig, allow_outdated_version: bool = False,
):
"""Builds and returns a database store using the provided configuration.
Args:
config: The database configuration
db_config: The database configuration
allow_outdated_version: True to suppress errors about the database server
version being too old to run a complete synapse
Returns:
The built Store object.
@ -468,16 +463,16 @@ class Porter(object):
self.progress.set_state("Preparing %s" % db_config.config["name"])
engine = create_engine(db_config.config)
conn = self.setup_db(db_config, engine)
hs = MockHomeserver(self.hs_config)
store = Store(Database(hs, db_config, engine), conn, hs)
yield store.db.runInteraction(
"%s_engine.check_database" % db_config.config["name"],
engine.check_database,
)
with make_conn(db_config, engine) as db_conn:
engine.check_database(
db_conn, allow_outdated_version=allow_outdated_version
)
prepare_database(db_conn, engine, config=self.hs_config)
store = Store(Database(hs, db_config, engine), db_conn, hs)
db_conn.commit()
return store
@ -502,8 +497,10 @@ class Porter(object):
@defer.inlineCallbacks
def run(self):
try:
self.sqlite_store = yield self.build_db_store(
DatabaseConnectionConfig("master-sqlite", self.sqlite_config)
# we allow people to port away from outdated versions of sqlite.
self.sqlite_store = self.build_db_store(
DatabaseConnectionConfig("master-sqlite", self.sqlite_config),
allow_outdated_version=True,
)
# Check if all background updates are done, abort if not.
@ -518,7 +515,7 @@ class Porter(object):
)
defer.returnValue(None)
self.postgres_store = yield self.build_db_store(
self.postgres_store = self.build_db_store(
self.hs_config.get_single_database()
)

View file

@ -17,13 +17,15 @@
"""Contains exceptions and error codes."""
import logging
from typing import Dict
from typing import Dict, List
from six import iteritems
from six.moves import http_client
from canonicaljson import json
from twisted.web import http
logger = logging.getLogger(__name__)
@ -80,6 +82,29 @@ class CodeMessageException(RuntimeError):
self.msg = msg
class RedirectException(CodeMessageException):
"""A pseudo-error indicating that we want to redirect the client to a different
location
Attributes:
cookies: a list of set-cookies values to add to the response. For example:
b"sessionId=a3fWa; Expires=Wed, 21 Oct 2015 07:28:00 GMT"
"""
def __init__(self, location: bytes, http_code: int = http.FOUND):
"""
Args:
location: the URI to redirect to
http_code: the HTTP response code
"""
msg = "Redirect to %s" % (location.decode("utf-8"),)
super().__init__(code=http_code, msg=msg)
self.location = location
self.cookies = [] # type: List[bytes]
class SynapseError(CodeMessageException):
"""A base exception type for matrix errors which have an errcode and error
message (as well as an HTTP status code).
@ -158,12 +183,6 @@ class UserDeactivatedError(SynapseError):
)
class RegistrationError(SynapseError):
"""An error raised when a registration event fails."""
pass
class FederationDeniedError(SynapseError):
"""An error raised when the server tries to federate with a server which
is not on its federation whitelist.

View file

@ -31,7 +31,7 @@ from prometheus_client import Gauge
from twisted.application import service
from twisted.internet import defer, reactor
from twisted.python.failure import Failure
from twisted.web.resource import EncodingResourceWrapper, NoResource
from twisted.web.resource import EncodingResourceWrapper, IResource, NoResource
from twisted.web.server import GzipEncoderFactory
from twisted.web.static import File
@ -109,7 +109,16 @@ class SynapseHomeServer(HomeServer):
for path, resmodule in additional_resources.items():
handler_cls, config = load_module(resmodule)
handler = handler_cls(config, module_api)
resources[path] = AdditionalResource(self, handler.handle_request)
if IResource.providedBy(handler):
resource = handler
elif hasattr(handler, "handle_request"):
resource = AdditionalResource(self, handler.handle_request)
else:
raise ConfigError(
"additional_resource %s does not implement a known interface"
% (resmodule["module"],)
)
resources[path] = resource
# try to find something useful to redirect '/' to
if WEB_CLIENT_PREFIX in resources:

View file

@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict
import six
from six import iteritems
@ -22,6 +23,7 @@ from six import iteritems
from canonicaljson import json
from prometheus_client import Counter
from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
@ -41,7 +43,11 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name
from synapse.logging.context import nested_logging_context
from synapse.logging.context import (
make_deferred_yieldable,
nested_logging_context,
run_in_background,
)
from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
from synapse.logging.utils import log_function
from synapse.replication.http.federation import (
@ -49,7 +55,7 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet,
)
from synapse.types import get_domain_from_id
from synapse.util import glob_to_regex
from synapse.util import glob_to_regex, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@ -160,6 +166,43 @@ class FederationServer(FederationBase):
)
return 400, response
# We process PDUs and EDUs in parallel. This is important as we don't
# want to block things like to device messages from reaching clients
# behind the potentially expensive handling of PDUs.
pdu_results, _ = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self._handle_pdus_in_txn, origin, transaction, request_time
),
run_in_background(self._handle_edus_in_txn, origin, transaction),
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
response = {"pdus": pdu_results}
logger.debug("Returning: %s", str(response))
await self.transaction_actions.set_response(origin, transaction, 200, response)
return 200, response
async def _handle_pdus_in_txn(
self, origin: str, transaction: Transaction, request_time: int
) -> Dict[str, dict]:
"""Process the PDUs in a received transaction.
Args:
origin: the server making the request
transaction: incoming transaction
request_time: timestamp that the HTTP request arrived at
Returns:
A map from event ID of a processed PDU to any errors we should
report back to the sending server.
"""
received_pdus_counter.inc(len(transaction.pdus))
origin_host, _ = parse_server_name(origin)
@ -250,20 +293,23 @@ class FederationServer(FederationBase):
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
)
if hasattr(transaction, "edus"):
for edu in (Edu(**x) for x in transaction.edus):
await self.received_edu(origin, edu.edu_type, edu.content)
return pdu_results
response = {"pdus": pdu_results}
async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
"""Process the EDUs in a received transaction.
"""
logger.debug("Returning: %s", str(response))
async def _process_edu(edu_dict):
received_edus_counter.inc()
await self.transaction_actions.set_response(origin, transaction, 200, response)
return 200, response
edu = Edu(**edu_dict)
await self.registry.on_edu(edu.edu_type, origin, edu.content)
async def received_edu(self, origin, edu_type, content):
received_edus_counter.inc()
await self.registry.on_edu(edu_type, origin, content)
await concurrently_execute(
_process_edu,
getattr(transaction, "edus", []),
TRANSACTION_CONCURRENCY_LIMIT,
)
async def on_context_state_request(self, origin, room_id, event_id):
origin_host, _ = parse_server_name(origin)

View file

@ -134,7 +134,7 @@ class AdminHandler(BaseHandler):
The returned value is that returned by `writer.finished()`.
"""
# Get all rooms the user is in or has been in
rooms = await self.store.get_rooms_for_user_where_membership_is(
rooms = await self.store.get_rooms_for_local_user_where_membership_is(
user_id,
membership_list=(
Membership.JOIN,

View file

@ -140,7 +140,7 @@ class DeactivateAccountHandler(BaseHandler):
user_id (str): The user ID to reject pending invites for.
"""
user = UserID.from_string(user_id)
pending_invites = await self.store.get_invited_rooms_for_user(user_id)
pending_invites = await self.store.get_invited_rooms_for_local_user(user_id)
for room in pending_invites:
try:

View file

@ -101,7 +101,7 @@ class InitialSyncHandler(BaseHandler):
if include_archived:
memberships.append(Membership.LEAVE)
room_list = await self.store.get_rooms_for_user_where_membership_is(
room_list = await self.store.get_rooms_for_local_user_where_membership_is(
user_id=user_id, membership_list=memberships
)

View file

@ -20,13 +20,7 @@ from twisted.internet import defer
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, LoginType
from synapse.api.errors import (
AuthError,
Codes,
ConsentNotGivenError,
RegistrationError,
SynapseError,
)
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
@ -165,7 +159,7 @@ class RegistrationHandler(BaseHandler):
Returns:
Deferred[str]: user_id
Raises:
RegistrationError if there was a problem registering.
SynapseError if there was a problem registering.
"""
yield self.check_registration_ratelimit(address)
@ -174,7 +168,7 @@ class RegistrationHandler(BaseHandler):
if password:
password_hash = yield self._auth_handler.hash(password)
if localpart:
if localpart is not None:
yield self.check_username(localpart, guest_access_token=guest_access_token)
was_guest = guest_access_token is not None
@ -182,7 +176,7 @@ class RegistrationHandler(BaseHandler):
if not was_guest:
try:
int(localpart)
raise RegistrationError(
raise SynapseError(
400, "Numeric user IDs are reserved for guest users."
)
except ValueError:

View file

@ -690,7 +690,7 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks
def _get_inviter(self, user_id, room_id):
invite = yield self.store.get_invite_for_user_in_room(
invite = yield self.store.get_invite_for_local_user_in_room(
user_id=user_id, room_id=room_id
)
if invite:

View file

@ -24,6 +24,7 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.http.servlet import parse_string
from synapse.module_api import ModuleApi
from synapse.rest.client.v1.login import SSOAuthHandler
from synapse.types import (
UserID,
@ -59,7 +60,8 @@ class SamlHandler:
# plugin to do custom mapping from saml response to mxid
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
hs.config.saml2_user_mapping_provider_config
hs.config.saml2_user_mapping_provider_config,
ModuleApi(hs, hs.get_auth_handler()),
)
# identifier for the external_ids table
@ -112,10 +114,10 @@ class SamlHandler:
# the dict.
self.expire_sessions()
user_id = await self._map_saml_response_to_user(resp_bytes)
user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(self, resp_bytes):
async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
try:
saml2_auth = self._saml_client.parse_authn_request_response(
resp_bytes,
@ -183,7 +185,7 @@ class SamlHandler:
# Map saml response to user attributes using the configured mapping provider
for i in range(1000):
attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
saml2_auth, i
saml2_auth, i, client_redirect_url=client_redirect_url,
)
logger.debug(
@ -216,6 +218,8 @@ class SamlHandler:
500, "Unable to generate a Matrix ID from the SAML response"
)
logger.info("Mapped SAML user to local part %s", localpart)
registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=displayname
)
@ -265,17 +269,21 @@ class SamlConfig(object):
class DefaultSamlMappingProvider(object):
__version__ = "0.0.1"
def __init__(self, parsed_config: SamlConfig):
def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi):
"""The default SAML user mapping provider
Args:
parsed_config: Module configuration
module_api: module api proxy
"""
self._mxid_source_attribute = parsed_config.mxid_source_attribute
self._mxid_mapper = parsed_config.mxid_mapper
def saml_response_to_user_attributes(
self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
self,
saml_response: saml2.response.AuthnResponse,
failures: int,
client_redirect_url: str,
) -> dict:
"""Maps some text from a SAML response to attributes of a new user
@ -285,6 +293,8 @@ class DefaultSamlMappingProvider(object):
failures: How many times a call to this function with this
saml_response has resulted in a failure
client_redirect_url: where the client wants to redirect to
Returns:
dict: A dict containing new user attributes. Possible keys:
* mxid_localpart (str): Required. The localpart of the user's mxid

View file

@ -179,7 +179,7 @@ class SearchHandler(BaseHandler):
search_filter = Filter(filter_dict)
# TODO: Search through left rooms too
rooms = yield self.store.get_rooms_for_user_where_membership_is(
rooms = yield self.store.get_rooms_for_local_user_where_membership_is(
user.to_string(),
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],

View file

@ -1662,7 +1662,7 @@ class SyncHandler(object):
Membership.BAN,
)
room_list = await self.store.get_rooms_for_user_where_membership_is(
room_list = await self.store.get_rooms_for_local_user_where_membership_is(
user_id=user_id, membership_list=membership_list
)

View file

@ -14,8 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import cgi
import collections
import html
import http.client
import logging
import types
@ -36,6 +36,7 @@ import synapse.metrics
from synapse.api.errors import (
CodeMessageException,
Codes,
RedirectException,
SynapseError,
UnrecognizedRequestError,
)
@ -153,14 +154,18 @@ def _return_html_error(f, request):
Args:
f (twisted.python.failure.Failure):
request (twisted.web.iweb.IRequest):
request (twisted.web.server.Request):
"""
if f.check(CodeMessageException):
cme = f.value
code = cme.code
msg = cme.msg
if isinstance(cme, SynapseError):
if isinstance(cme, RedirectException):
logger.info("%s redirect to %s", request, cme.location)
request.setHeader(b"location", cme.location)
request.cookies.extend(cme.cookies)
elif isinstance(cme, SynapseError):
logger.info("%s SynapseError: %s - %s", request, code, msg)
else:
logger.error(
@ -178,7 +183,7 @@ def _return_html_error(f, request):
exc_info=(f.type, f.value, f.getTracebackObject()),
)
body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8")
body = HTML_ERROR_TEMPLATE.format(code=code, msg=html.escape(msg)).encode("utf-8")
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%i" % (len(body),))

View file

@ -88,7 +88,7 @@ class SynapseRequest(Request):
def get_redacted_uri(self):
uri = self.uri
if isinstance(uri, bytes):
uri = self.uri.decode("ascii")
uri = self.uri.decode("ascii", errors="replace")
return redact_uri(uri)
def get_method(self):

View file

@ -571,6 +571,9 @@ def run_in_background(f, *args, **kwargs):
yield or await on (for instance because you want to pass it to
deferred.gatherResults()).
If f returns a Coroutine object, it will be wrapped into a Deferred (which will have
the side effect of executing the coroutine).
Note that if you completely discard the result, you should make sure that
`f` doesn't raise any deferred exceptions, otherwise a scary-looking
CRITICAL error about an unhandled error will be logged without much

View file

@ -21,7 +21,7 @@ from synapse.storage import Storage
@defer.inlineCallbacks
def get_badge_count(store, user_id):
invites = yield store.get_invited_rooms_for_user(user_id)
invites = yield store.get_invited_rooms_for_local_user(user_id)
joins = yield store.get_rooms_for_user(user_id)
my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read")

View file

@ -16,6 +16,7 @@
import abc
import logging
import re
from typing import Dict, List, Tuple
from six import raise_from
from six.moves import urllib
@ -78,9 +79,8 @@ class ReplicationEndpoint(object):
__metaclass__ = abc.ABCMeta
NAME = abc.abstractproperty()
PATH_ARGS = abc.abstractproperty()
NAME = abc.abstractproperty() # type: str # type: ignore
PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore
METHOD = "POST"
CACHE = True
RETRY_ON_TIMEOUT = True
@ -171,7 +171,7 @@ class ReplicationEndpoint(object):
# have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not.
while True:
headers = {}
headers = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers, None, check_destination=False)
try:
result = yield request_func(uri, data, headers=headers)
@ -207,7 +207,7 @@ class ReplicationEndpoint(object):
method = self.METHOD
if self.CACHE:
handler = self._cached_handler
handler = self._cached_handler # type: ignore
url_args.append("txn_id")
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)

View file

@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import Dict
from typing import Dict, Optional
import six
@ -41,7 +41,7 @@ class BaseSlavedStore(SQLBaseStore):
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id"
)
) # type: Optional[SlavedIdTracker]
else:
self._cache_id_gen = None
@ -62,7 +62,8 @@ class BaseSlavedStore(SQLBaseStore):
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches":
self._cache_id_gen.advance(token)
if self._cache_id_gen:
self._cache_id_gen.advance(token)
for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME:
room_id = row.keys[0]

View file

@ -152,7 +152,7 @@ class SlavedEventStore(
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
self.get_invited_rooms_for_user.invalidate((state_key,))
self.get_invited_rooms_for_local_user.invalidate((state_key,))
if relates_to:
self.get_relations_for_event.invalidate_many((relates_to,))

View file

@ -29,7 +29,7 @@ class SlavedPresenceStore(BaseSlavedStore):
self._presence_on_startup = self._get_active_presence(db_conn)
self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache(
self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
)

View file

@ -16,7 +16,7 @@
"""
import logging
from typing import Dict
from typing import Dict, List, Optional
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
@ -28,6 +28,7 @@ from synapse.replication.tcp.protocol import (
)
from .commands import (
Command,
FederationAckCommand,
InvalidateCacheCommand,
RemovePusherCommand,
@ -89,15 +90,15 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
# Any pending commands to be sent once a new connection has been
# established
self.pending_commands = []
self.pending_commands = [] # type: List[Command]
# Map from string -> deferred, to wake up when receiveing a SYNC with
# the given string.
# Used for tests.
self.awaiting_syncs = {}
self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
# The factory used to create connections.
self.factory = None
self.factory = None # type: Optional[ReplicationClientFactory]
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
@ -235,4 +236,5 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
# We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the
# server.
self.factory.resetDelay()
if self.factory:
self.factory.resetDelay()

View file

@ -20,15 +20,16 @@ allowed to be sent by which side.
import logging
import platform
from typing import Tuple, Type
if platform.python_implementation() == "PyPy":
import json
_json_encoder = json.JSONEncoder()
else:
import simplejson as json
import simplejson as json # type: ignore[no-redef] # noqa: F821
_json_encoder = json.JSONEncoder(namedtuple_as_object=False)
_json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821
logger = logging.getLogger(__name__)
@ -44,7 +45,7 @@ class Command(object):
The default implementation creates a command of form `<NAME> <data>`
"""
NAME = None
NAME = None # type: str
def __init__(self, data):
self.data = data
@ -386,25 +387,24 @@ class UserIpCommand(Command):
)
_COMMANDS = (
ServerCommand,
RdataCommand,
PositionCommand,
ErrorCommand,
PingCommand,
NameCommand,
ReplicateCommand,
UserSyncCommand,
FederationAckCommand,
SyncCommand,
RemovePusherCommand,
InvalidateCacheCommand,
UserIpCommand,
) # type: Tuple[Type[Command], ...]
# Map of command name to command type.
COMMAND_MAP = {
cmd.NAME: cmd
for cmd in (
ServerCommand,
RdataCommand,
PositionCommand,
ErrorCommand,
PingCommand,
NameCommand,
ReplicateCommand,
UserSyncCommand,
FederationAckCommand,
SyncCommand,
RemovePusherCommand,
InvalidateCacheCommand,
UserIpCommand,
)
}
COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS}
# The commands the server is allowed to send
VALID_SERVER_COMMANDS = (

View file

@ -53,6 +53,7 @@ import fcntl
import logging
import struct
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Set, Tuple
from six import iteritems, iterkeys
@ -65,13 +66,11 @@ from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock
from synapse.util.stringutils import random_string
from .commands import (
from synapse.replication.tcp.commands import (
COMMAND_MAP,
VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS,
Command,
ErrorCommand,
NameCommand,
PingCommand,
@ -82,6 +81,10 @@ from .commands import (
SyncCommand,
UserSyncCommand,
)
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string
from .streams import STREAMS_MAP
connection_close_counter = Counter(
@ -124,8 +127,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
delimiter = b"\n"
VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive
VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send
# Valid commands we expect to receive
VALID_INBOUND_COMMANDS = [] # type: Collection[str]
# Valid commands we can send
VALID_OUTBOUND_COMMANDS = [] # type: Collection[str]
max_line_buffer = 10000
@ -144,13 +150,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.conn_id = random_string(5) # To dedupe in case of name clashes.
# List of pending commands to send once we've established the connection
self.pending_commands = []
self.pending_commands = [] # type: List[Command]
# The LoopingCall for sending pings.
self._send_ping_loop = None
self.inbound_commands_counter = defaultdict(int)
self.outbound_commands_counter = defaultdict(int)
self.inbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
self.outbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@ -409,14 +415,14 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.streamer = streamer
# The streams the client has subscribed to and is up to date with
self.replication_streams = set()
self.replication_streams = set() # type: Set[str]
# The streams the client is currently subscribing to.
self.connecting_streams = set()
self.connecting_streams = set() # type: Set[str]
# Map from stream name to list of updates to send once we've finished
# subscribing the client to the stream.
self.pending_rdata = {}
self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
@ -642,11 +648,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
self.streams_connecting = set()
self.streams_connecting = set() # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
self.pending_batches = {}
self.pending_batches = {} # type: Dict[str, Any]
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
@ -766,7 +772,7 @@ def transport_kernel_read_buffer_size(protocol, read=True):
op = SIOCINQ
else:
op = SIOCOUTQ
size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0]
size = struct.unpack("I", fcntl.ioctl(fileno, op, b"\0\0\0\0"))[0]
return size
return 0

View file

@ -17,6 +17,7 @@
import logging
import random
from typing import List
from six import itervalues
@ -79,7 +80,7 @@ class ReplicationStreamer(object):
self._replication_torture_level = hs.config.replication_torture_level
# Current connections.
self.connections = []
self.connections = [] # type: List[ServerReplicationStreamProtocol]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",

View file

@ -14,10 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from collections import namedtuple
from typing import Any
from twisted.internet import defer
@ -104,8 +104,9 @@ class Stream(object):
time it was called up until the point `advance_current_token` was called.
"""
NAME = None # The name of the stream
ROW_TYPE = None # The type of the row. Used by the default impl of parse_row.
NAME = None # type: str # The name of the stream
# The type of the row. Used by the default impl of parse_row.
ROW_TYPE = None # type: Any
_LIMITED = True # Whether the update function takes a limit
@classmethod
@ -231,8 +232,8 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_current_backfill_token
self.update_function = store.get_all_new_backfill_event_rows
self.current_token = store.get_current_backfill_token # type: ignore
self.update_function = store.get_all_new_backfill_event_rows # type: ignore
super(BackfillStream, self).__init__(hs)
@ -246,8 +247,8 @@ class PresenceStream(Stream):
store = hs.get_datastore()
presence_handler = hs.get_presence_handler()
self.current_token = store.get_current_presence_token
self.update_function = presence_handler.get_all_presence_updates
self.current_token = store.get_current_presence_token # type: ignore
self.update_function = presence_handler.get_all_presence_updates # type: ignore
super(PresenceStream, self).__init__(hs)
@ -260,8 +261,8 @@ class TypingStream(Stream):
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
self.current_token = typing_handler.get_current_token
self.update_function = typing_handler.get_all_typing_updates
self.current_token = typing_handler.get_current_token # type: ignore
self.update_function = typing_handler.get_all_typing_updates # type: ignore
super(TypingStream, self).__init__(hs)
@ -273,8 +274,8 @@ class ReceiptsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_max_receipt_stream_id
self.update_function = store.get_all_updated_receipts
self.current_token = store.get_max_receipt_stream_id # type: ignore
self.update_function = store.get_all_updated_receipts # type: ignore
super(ReceiptsStream, self).__init__(hs)
@ -310,8 +311,8 @@ class PushersStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_pushers_stream_token
self.update_function = store.get_all_updated_pushers_rows
self.current_token = store.get_pushers_stream_token # type: ignore
self.update_function = store.get_all_updated_pushers_rows # type: ignore
super(PushersStream, self).__init__(hs)
@ -327,8 +328,8 @@ class CachesStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_cache_stream_token
self.update_function = store.get_all_updated_caches
self.current_token = store.get_cache_stream_token # type: ignore
self.update_function = store.get_all_updated_caches # type: ignore
super(CachesStream, self).__init__(hs)
@ -343,8 +344,8 @@ class PublicRoomsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_current_public_room_stream_id
self.update_function = store.get_all_new_public_rooms
self.current_token = store.get_current_public_room_stream_id # type: ignore
self.update_function = store.get_all_new_public_rooms # type: ignore
super(PublicRoomsStream, self).__init__(hs)
@ -360,8 +361,8 @@ class DeviceListsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token
self.update_function = store.get_all_device_list_changes_for_remotes
self.current_token = store.get_device_stream_token # type: ignore
self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
super(DeviceListsStream, self).__init__(hs)
@ -376,8 +377,8 @@ class ToDeviceStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_to_device_stream_token
self.update_function = store.get_all_new_device_messages
self.current_token = store.get_to_device_stream_token # type: ignore
self.update_function = store.get_all_new_device_messages # type: ignore
super(ToDeviceStream, self).__init__(hs)
@ -392,8 +393,8 @@ class TagAccountDataStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_max_account_data_stream_id
self.update_function = store.get_all_updated_tags
self.current_token = store.get_max_account_data_stream_id # type: ignore
self.update_function = store.get_all_updated_tags # type: ignore
super(TagAccountDataStream, self).__init__(hs)
@ -408,7 +409,7 @@ class AccountDataStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
self.current_token = self.store.get_max_account_data_stream_id
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
super(AccountDataStream, self).__init__(hs)
@ -434,8 +435,8 @@ class GroupServerStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_group_stream_token
self.update_function = store.get_all_groups_changes
self.current_token = store.get_group_stream_token # type: ignore
self.update_function = store.get_all_groups_changes # type: ignore
super(GroupServerStream, self).__init__(hs)
@ -451,7 +452,7 @@ class UserSignatureStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token
self.update_function = store.get_all_user_signature_changes_for_remotes
self.current_token = store.get_device_stream_token # type: ignore
self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
super(UserSignatureStream, self).__init__(hs)

View file

@ -13,7 +13,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import heapq
from typing import Tuple, Type
import attr
@ -63,7 +65,8 @@ class BaseEventsStreamRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
# Unique string that ids the type. Must be overriden in sub classes.
TypeId = None # type: str
@classmethod
def from_data(cls, data):
@ -99,9 +102,12 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
event_id = attr.ib() # str, optional
TypeToRow = {
Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow)
}
_EventRows = (
EventsStreamEventRow,
EventsStreamCurrentStateRow,
) # type: Tuple[Type[BaseEventsStreamRow], ...]
TypeToRow = {Row.TypeId: Row for Row in _EventRows}
class EventsStream(Stream):
@ -112,7 +118,7 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
self.current_token = self._store.get_current_events_token
self.current_token = self._store.get_current_events_token # type: ignore
super(EventsStream, self).__init__(hs)

View file

@ -37,7 +37,7 @@ class FederationStream(Stream):
def __init__(self, hs):
federation_sender = hs.get_federation_sender()
self.current_token = federation_sender.get_current_token
self.update_function = federation_sender.get_replication_rows
self.current_token = federation_sender.get_current_token # type: ignore
self.update_function = federation_sender.get_replication_rows # type: ignore
super(FederationStream, self).__init__(hs)

View file

@ -32,16 +32,24 @@ class QuarantineMediaInRoom(RestServlet):
this server.
"""
PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
PATTERNS = (
historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
+
# This path kept around for legacy reasons
historical_admin_path_patterns("/quarantine_media/(?P<room_id>![^/]+)")
)
def __init__(self, hs):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request, room_id):
async def on_POST(self, request, room_id: str):
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
logging.info("Quarantining room: %s", room_id)
# Quarantine all media in this room
num_quarantined = await self.store.quarantine_media_ids_in_room(
room_id, requester.user.to_string()
)
@ -49,6 +57,60 @@ class QuarantineMediaInRoom(RestServlet):
return 200, {"num_quarantined": num_quarantined}
class QuarantineMediaByUser(RestServlet):
"""Quarantines all local media by a given user so that no one can download it via
this server.
"""
PATTERNS = historical_admin_path_patterns(
"/user/(?P<user_id>[^/]+)/media/quarantine"
)
def __init__(self, hs):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request, user_id: str):
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
logging.info("Quarantining local media by user: %s", user_id)
# Quarantine all media this user has uploaded
num_quarantined = await self.store.quarantine_media_ids_by_user(
user_id, requester.user.to_string()
)
return 200, {"num_quarantined": num_quarantined}
class QuarantineMediaByID(RestServlet):
"""Quarantines local or remote media by a given ID so that no one can download
it via this server.
"""
PATTERNS = historical_admin_path_patterns(
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
)
def __init__(self, hs):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request, server_name: str, media_id: str):
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
logging.info("Quarantining local media by ID: %s/%s", server_name, media_id)
# Quarantine this media id
await self.store.quarantine_media_by_id(
server_name, media_id, requester.user.to_string()
)
return 200, {}
class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
@ -94,4 +156,6 @@ def register_servlets_for_media_repo(hs, http_server):
"""
PurgeMediaCacheRestServlet(hs).register(http_server)
QuarantineMediaInRoom(hs).register(http_server)
QuarantineMediaByID(hs).register(http_server)
QuarantineMediaByUser(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)

View file

@ -105,7 +105,7 @@ class ServerNoticesManager(object):
assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
rooms = yield self._store.get_rooms_for_user_where_membership_is(
rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN]
)
system_mxid = self._config.server_notices_mxid

View file

@ -47,7 +47,7 @@ class DataStores(object):
with make_conn(database_config, engine) as db_conn:
logger.info("Preparing database %r...", db_name)
engine.check_database(db_conn.cursor())
engine.check_database(db_conn)
prepare_database(
db_conn, engine, hs.config, data_stores=database_config.data_stores,
)

View file

@ -128,6 +128,7 @@ class EventsStore(
hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
@defer.inlineCallbacks
def _read_forward_extremities(self):
@ -547,6 +548,34 @@ class EventsStore(
],
)
# Note: Do we really want to delete rows here (that we do not
# subsequently reinsert below)? While technically correct it means
# we have no record of the fact the user *was* a member of the
# room but got, say, state reset out of it.
if to_delete or to_insert:
txn.executemany(
"DELETE FROM local_current_membership"
" WHERE room_id = ? AND user_id = ?",
(
(room_id, state_key)
for etype, state_key in itertools.chain(to_delete, to_insert)
if etype == EventTypes.Member and self.is_mine_id(state_key)
),
)
if to_insert:
txn.executemany(
"""INSERT INTO local_current_membership
(room_id, user_id, event_id, membership)
VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
""",
[
(room_id, key[1], ev_id, ev_id)
for key, ev_id in to_insert.items()
if key[0] == EventTypes.Member and self.is_mine_id(key[1])
],
)
txn.call_after(
self._curr_state_delta_stream_cache.entity_has_changed,
room_id,
@ -1724,6 +1753,7 @@ class EventsStore(
"local_invites",
"room_account_data",
"room_tags",
"local_current_membership",
):
logger.info("[purge] removing %s from %s", room_id, table)
txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))

View file

@ -18,7 +18,7 @@ import collections
import logging
import re
from abc import abstractmethod
from typing import Optional, Tuple
from typing import List, Optional, Tuple
from six import integer_types
@ -399,6 +399,8 @@ class RoomWorkerStore(SQLBaseStore):
the associated media
"""
logger.info("Quarantining media in room: %s", room_id)
def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
total_media_quarantined = 0
@ -494,6 +496,118 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
def quarantine_media_by_id(
self, server_name: str, media_id: str, quarantined_by: str,
):
"""quarantines a single local or remote media id
Args:
server_name: The name of the server that holds this media
media_id: The ID of the media to be quarantined
quarantined_by: The user ID that initiated the quarantine request
"""
logger.info("Quarantining media: %s/%s", server_name, media_id)
is_local = server_name == self.config.server_name
def _quarantine_media_by_id_txn(txn):
local_mxcs = [media_id] if is_local else []
remote_mxcs = [(server_name, media_id)] if not is_local else []
return self._quarantine_media_txn(
txn, local_mxcs, remote_mxcs, quarantined_by
)
return self.db.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_id_txn
)
def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
"""quarantines all local media associated with a single user
Args:
user_id: The ID of the user to quarantine media of
quarantined_by: The ID of the user who made the quarantine request
"""
def _quarantine_media_by_user_txn(txn):
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
return self.db.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_user_txn
)
def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True):
"""Retrieves local media IDs by a given user
Args:
txn (cursor)
user_id: The ID of the user to retrieve media IDs of
Returns:
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
"""
# Local media
sql = """
SELECT media_id
FROM local_media_repository
WHERE user_id = ?
"""
if filter_quarantined:
sql += "AND quarantined_by IS NULL"
txn.execute(sql, (user_id,))
local_media_ids = [row[0] for row in txn]
# TODO: Figure out all remote media a user has referenced in a message
return local_media_ids
def _quarantine_media_txn(
self,
txn,
local_mxcs: List[str],
remote_mxcs: List[Tuple[str, str]],
quarantined_by: str,
) -> int:
"""Quarantine local and remote media items
Args:
txn (cursor)
local_mxcs: A list of local mxc URLs
remote_mxcs: A list of (remote server, media id) tuples representing
remote mxc URLs
quarantined_by: The ID of the user who initiated the quarantine request
Returns:
The total number of media items quarantined
"""
total_media_quarantined = 0
# Update all the tables to set the quarantined_by flag
txn.executemany(
"""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
""",
((quarantined_by, media_id) for media_id in local_mxcs),
)
txn.executemany(
"""
UPDATE remote_media_cache
SET quarantined_by = ?
WHERE media_origin = ? AND media_id = ?
""",
((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
)
total_media_quarantined += len(local_mxcs)
total_media_quarantined += len(remote_mxcs)
return total_media_quarantined
class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"

View file

@ -297,19 +297,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return {row[0]: row[1] for row in txn}
@cached()
def get_invited_rooms_for_user(self, user_id):
""" Get all the rooms the user is invited to
def get_invited_rooms_for_local_user(self, user_id):
""" Get all the rooms the *local* user is invited to
Args:
user_id (str): The user ID.
Returns:
A deferred list of RoomsForUser.
"""
return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE])
return self.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE]
)
@defer.inlineCallbacks
def get_invite_for_user_in_room(self, user_id, room_id):
"""Gets the invite for the given user and room
def get_invite_for_local_user_in_room(self, user_id, room_id):
"""Gets the invite for the given *local* user and room
Args:
user_id (str)
@ -319,15 +322,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Deferred: Resolves to either a RoomsForUser or None if no invite was
found.
"""
invites = yield self.get_invited_rooms_for_user(user_id)
invites = yield self.get_invited_rooms_for_local_user(user_id)
for invite in invites:
if invite.room_id == room_id:
return invite
return None
@defer.inlineCallbacks
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user
def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
Filters out forgotten rooms.
@ -344,8 +347,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return defer.succeed(None)
rooms = yield self.db.runInteraction(
"get_rooms_for_user_where_membership_is",
self._get_rooms_for_user_where_membership_is_txn,
"get_rooms_for_local_user_where_membership_is",
self._get_rooms_for_local_user_where_membership_is_txn,
user_id,
membership_list,
)
@ -354,76 +357,42 @@ class RoomMemberWorkerStore(EventsWorkerStore):
forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
return [room for room in rooms if room.room_id not in forgotten_rooms]
def _get_rooms_for_user_where_membership_is_txn(
def _get_rooms_for_local_user_where_membership_is_txn(
self, txn, user_id, membership_list
):
do_invite = Membership.INVITE in membership_list
membership_list = [m for m in membership_list if m != Membership.INVITE]
results = []
if membership_list:
if self._current_state_events_membership_up_to_date:
clause, args = make_in_list_sql_clause(
self.database_engine, "c.membership", membership_list
)
sql = """
SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
FROM current_state_events AS c
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
AND state_key = ?
AND %s
""" % (
clause,
)
else:
clause, args = make_in_list_sql_clause(
self.database_engine, "m.membership", membership_list
)
sql = """
SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
FROM current_state_events AS c
INNER JOIN room_memberships AS m USING (room_id, event_id)
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
AND state_key = ?
AND %s
""" % (
clause,
)
txn.execute(sql, (user_id, *args))
results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
if do_invite:
sql = (
"SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
" FROM local_invites as i"
" INNER JOIN events as e USING (event_id)"
" WHERE invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
# Paranoia check.
if not self.hs.is_mine_id(user_id):
raise Exception(
"Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r"
% (user_id,),
)
txn.execute(sql, (user_id,))
results.extend(
RoomsForUser(
room_id=r["room_id"],
sender=r["inviter"],
event_id=r["event_id"],
stream_ordering=r["stream_ordering"],
membership=Membership.INVITE,
)
for r in self.db.cursor_to_dict(txn)
)
clause, args = make_in_list_sql_clause(
self.database_engine, "c.membership", membership_list
)
sql = """
SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
FROM local_current_membership AS c
INNER JOIN events AS e USING (room_id, event_id)
WHERE
user_id = ?
AND %s
""" % (
clause,
)
txn.execute(sql, (user_id, *args))
results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
return results
@cachedInlineCallbacks(max_entries=500000, iterable=True)
@cached(max_entries=500000, iterable=True)
def get_rooms_for_user_with_stream_ordering(self, user_id):
"""Returns a set of room_ids the user is currently joined to
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
participating in.
Args:
user_id (str)
@ -433,17 +402,49 @@ class RoomMemberWorkerStore(EventsWorkerStore):
the rooms the user is in currently, along with the stream ordering
of the most recent join for that user and room.
"""
rooms = yield self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN]
)
return frozenset(
GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
for r in rooms
return self.db.runInteraction(
"get_rooms_for_user_with_stream_ordering",
self._get_rooms_for_user_with_stream_ordering_txn,
user_id,
)
def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
# for rooms the server is participating in.
if self._current_state_events_membership_up_to_date:
sql = """
SELECT room_id, e.stream_ordering
FROM current_state_events AS c
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
AND state_key = ?
AND c.membership = ?
"""
else:
sql = """
SELECT room_id, e.stream_ordering
FROM current_state_events AS c
INNER JOIN room_memberships AS m USING (room_id, event_id)
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
AND state_key = ?
AND m.membership = ?
"""
txn.execute(sql, (user_id, Membership.JOIN))
results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
return results
@defer.inlineCallbacks
def get_rooms_for_user(self, user_id, on_invalidate=None):
"""Returns a set of room_ids the user is currently joined to
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
participating in.
"""
rooms = yield self.get_rooms_for_user_with_stream_ordering(
user_id, on_invalidate=on_invalidate
@ -1022,7 +1023,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
event.internal_metadata.stream_ordering,
)
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
self.get_invited_rooms_for_local_user.invalidate, (event.state_key,)
)
# We update the local_invites table only if the event is "current",
@ -1064,6 +1065,27 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
),
)
# We also update the `local_current_membership` table with
# latest invite info. This will usually get updated by the
# `current_state_events` handling, unless its an outlier.
if event.internal_metadata.is_outlier():
# This should only happen for out of band memberships, so
# we add a paranoia check.
assert event.internal_metadata.is_out_of_band_membership()
self.db.simple_upsert_txn(
txn,
table="local_current_membership",
keyvalues={
"room_id": event.room_id,
"user_id": event.state_key,
},
values={
"event_id": event.event_id,
"membership": event.membership,
},
)
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
@ -1075,6 +1097,15 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def f(txn, stream_ordering):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
# We also clear this entry from `local_current_membership`.
# Ideally we'd point to a leave event, but we don't have one, so
# nevermind.
self.db.simple_delete_txn(
txn,
table="local_current_membership",
keyvalues={"room_id": room_id, "user_id": user_id},
)
with self._stream_id_gen.get_next() as stream_ordering:
yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)

View file

@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
# Copyright 2020 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# We create a new table called `local_current_membership` that stores the latest
# membership state of local users in rooms, which helps track leaves/bans/etc
# even if the server has left the room (and so has deleted the room from
# `current_state_events`). This will also include outstanding invites for local
# users for rooms the server isn't in.
#
# If the server isn't and hasn't been in the room then it will only include
# outsstanding invites, and not e.g. pre-emptive bans of local users.
#
# If the server later rejoins a room `local_current_membership` can simply be
# replaced with the new current state of the room (which results in the
# equivalent behaviour as if the server had remained in the room).
def run_upgrade(cur, database_engine, config, *args, **kwargs):
# We need to do the insert in `run_upgrade` section as we don't have access
# to `config` in `run_create`.
# This upgrade may take a bit of time for large servers (e.g. one minute for
# matrix.org) but means we avoid a lots of book keeping required to do it as
# a background update.
# We check if the `current_state_events.membership` is up to date by
# checking if the relevant background update has finished. If it has
# finished we can avoid doing a join against `room_memberships`, which
# speesd things up.
cur.execute(
"""SELECT 1 FROM background_updates
WHERE update_name = 'current_state_events_membership'
"""
)
current_state_membership_up_to_date = not bool(cur.fetchone())
# Cheekily drop and recreate indices, as that is faster.
cur.execute("DROP INDEX local_current_membership_idx")
cur.execute("DROP INDEX local_current_membership_room_idx")
if current_state_membership_up_to_date:
sql = """
INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
SELECT c.room_id, state_key AS user_id, event_id, c.membership
FROM current_state_events AS c
WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key like '%' || ?
"""
else:
# We can't rely on the membership column, so we need to join against
# `room_memberships`.
sql = """
INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
SELECT c.room_id, state_key AS user_id, event_id, r.membership
FROM current_state_events AS c
INNER JOIN room_memberships AS r USING (event_id)
WHERE type = 'm.room.member' and state_key like '%' || ?
"""
cur.execute(sql, (config.server_name,))
cur.execute(
"CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
)
cur.execute(
"CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
)
def run_create(cur, database_engine, *args, **kwargs):
cur.execute(
"""
CREATE TABLE local_current_membership (
room_id TEXT NOT NULL,
user_id TEXT NOT NULL,
event_id TEXT NOT NULL,
membership TEXT NOT NULL
)"""
)
cur.execute(
"CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
)
cur.execute(
"CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
)

View file

@ -32,20 +32,7 @@ class PostgresEngine(object):
self.synchronous_commit = database_config.get("synchronous_commit", True)
self._version = None # unknown as yet
def check_database(self, txn):
txn.execute("SHOW SERVER_ENCODING")
rows = txn.fetchall()
if rows and rows[0][0] != "UTF8":
raise IncorrectDatabaseSetup(
"Database has incorrect encoding: '%s' instead of 'UTF8'\n"
"See docs/postgres.rst for more information." % (rows[0][0],)
)
def convert_param_style(self, sql):
return sql.replace("?", "%s")
def on_new_connection(self, db_conn):
def check_database(self, db_conn, allow_outdated_version: bool = False):
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
@ -53,9 +40,22 @@ class PostgresEngine(object):
self._version = db_conn.server_version
# Are we on a supported PostgreSQL version?
if self._version < 90500:
if not allow_outdated_version and self._version < 90500:
raise RuntimeError("Synapse requires PostgreSQL 9.5+ or above.")
with db_conn.cursor() as txn:
txn.execute("SHOW SERVER_ENCODING")
rows = txn.fetchall()
if rows and rows[0][0] != "UTF8":
raise IncorrectDatabaseSetup(
"Database has incorrect encoding: '%s' instead of 'UTF8'\n"
"See docs/postgres.rst for more information." % (rows[0][0],)
)
def convert_param_style(self, sql):
return sql.replace("?", "%s")
def on_new_connection(self, db_conn):
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
@ -119,8 +119,8 @@ class PostgresEngine(object):
Returns:
string
"""
# note that this is a bit of a hack because it relies on on_new_connection
# having been called at least once. Still, that should be a safe bet here.
# note that this is a bit of a hack because it relies on check_database
# having been called. Still, that should be a safe bet here.
numver = self._version
assert numver is not None

View file

@ -53,8 +53,11 @@ class Sqlite3Engine(object):
"""
return False
def check_database(self, txn):
pass
def check_database(self, db_conn, allow_outdated_version: bool = False):
if not allow_outdated_version:
version = self.module.sqlite_version_info
if version < (3, 11, 0):
raise RuntimeError("Synapse requires sqlite 3.11 or above.")
def convert_param_style(self, sql):
return sql

View file

@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 56
SCHEMA_VERSION = 57
dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -269,8 +269,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
one will be randomly generated.
Returns:
A tuple of (user_id, access_token).
Raises:
RegistrationError if there was a problem registering.
"""
if localpart is None:
raise SynapseError(400, "Request must include user id")

View file

@ -32,8 +32,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
user_id1 = "@user1:test"
user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1)
self.reactor.advance(100) # So we get not 0 time

View file

@ -115,13 +115,13 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
def test_invites(self):
self.persist(type="m.room.create", key="", creator=USER_ID)
self.check("get_invited_rooms_for_user", [USER_ID_2], [])
self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
self.replicate()
self.check(
"get_invited_rooms_for_user",
"get_invited_rooms_for_local_user",
[USER_ID_2],
[
RoomsForUser(

View file

@ -14,11 +14,17 @@
# limitations under the License.
import json
import os
import urllib.parse
from binascii import unhexlify
from mock import Mock
from twisted.internet.defer import Deferred
import synapse.rest.admin
from synapse.http.server import JsonResource
from synapse.logging.context import make_deferred_yieldable
from synapse.rest.admin import VersionServlet
from synapse.rest.client.v1 import events, login, room
from synapse.rest.client.v2_alpha import groups
@ -346,3 +352,338 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
test_purge_room.skip = "Disabled because it's currently broken"
class QuarantineMediaTestCase(unittest.HomeserverTestCase):
"""Test /quarantine_media admin API.
"""
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
login.register_servlets,
room.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.hs = hs
# Allow for uploading and downloading to/from the media repo
self.media_repo = hs.get_media_repository_resource()
self.download_resource = self.media_repo.children[b"download"]
self.upload_resource = self.media_repo.children[b"upload"]
self.image_data = unhexlify(
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
b"0000001f15c4890000000a49444154789c63000100000500010d"
b"0a2db40000000049454e44ae426082"
)
def make_homeserver(self, reactor, clock):
self.fetches = []
def get_file(destination, path, output_stream, args=None, max_size=None):
"""
Returns tuple[int,dict,str,int] of file length, response headers,
absolute URI, and response code.
"""
def write_to(r):
data, response = r
output_stream.write(data)
return response
d = Deferred()
d.addCallback(write_to)
self.fetches.append((d, destination, path, args))
return make_deferred_yieldable(d)
client = Mock()
client.get_file = get_file
self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
os.mkdir(self.storage_path)
os.mkdir(self.media_store_path)
config = self.default_config()
config["media_store_path"] = self.media_store_path
config["thumbnail_requirements"] = {}
config["max_image_pixels"] = 2000000
provider_config = {
"module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
"store_local": True,
"store_synchronous": False,
"store_remote": True,
"config": {"directory": self.storage_path},
}
config["media_storage_providers"] = [provider_config]
hs = self.setup_test_homeserver(config=config, http_client=client)
return hs
def test_quarantine_media_requires_admin(self):
self.register_user("nonadmin", "pass", admin=False)
non_admin_user_tok = self.login("nonadmin", "pass")
# Attempt quarantine media APIs as non-admin
url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345"
request, channel = self.make_request(
"POST", url.encode("ascii"), access_token=non_admin_user_tok,
)
self.render(request)
# Expect a forbidden error
self.assertEqual(
403,
int(channel.result["code"]),
msg="Expected forbidden on quarantining media as a non-admin",
)
# And the roomID/userID endpoint
url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine"
request, channel = self.make_request(
"POST", url.encode("ascii"), access_token=non_admin_user_tok,
)
self.render(request)
# Expect a forbidden error
self.assertEqual(
403,
int(channel.result["code"]),
msg="Expected forbidden on quarantining media as a non-admin",
)
def test_quarantine_media_by_id(self):
self.register_user("id_admin", "pass", admin=True)
admin_user_tok = self.login("id_admin", "pass")
self.register_user("id_nonadmin", "pass", admin=False)
non_admin_user_tok = self.login("id_nonadmin", "pass")
# Upload some media into the room
response = self.helper.upload_media(
self.upload_resource, self.image_data, tok=admin_user_tok
)
# Extract media ID from the response
server_name_and_media_id = response["content_uri"][
6:
] # Cut off the 'mxc://' bit
server_name, media_id = server_name_and_media_id.split("/")
# Attempt to access the media
request, channel = self.make_request(
"GET",
server_name_and_media_id,
shorthand=False,
access_token=non_admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
# Should be successful
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
# Quarantine the media
url = "/_synapse/admin/v1/media/quarantine/%s/%s" % (
urllib.parse.quote(server_name),
urllib.parse.quote(media_id),
)
request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
self.render(request)
self.pump(1.0)
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
# Attempt to access the media
request, channel = self.make_request(
"GET",
server_name_and_media_id,
shorthand=False,
access_token=admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_name_and_media_id
),
)
def test_quarantine_all_media_in_room(self):
self.register_user("room_admin", "pass", admin=True)
admin_user_tok = self.login("room_admin", "pass")
non_admin_user = self.register_user("room_nonadmin", "pass", admin=False)
non_admin_user_tok = self.login("room_nonadmin", "pass")
room_id = self.helper.create_room_as(non_admin_user, tok=admin_user_tok)
self.helper.join(room_id, non_admin_user, tok=non_admin_user_tok)
# Upload some media
response_1 = self.helper.upload_media(
self.upload_resource, self.image_data, tok=non_admin_user_tok
)
response_2 = self.helper.upload_media(
self.upload_resource, self.image_data, tok=non_admin_user_tok
)
# Extract mxcs
mxc_1 = response_1["content_uri"]
mxc_2 = response_2["content_uri"]
# Send it into the room
self.helper.send_event(
room_id,
"m.room.message",
content={"body": "image-1", "msgtype": "m.image", "url": mxc_1},
txn_id="111",
tok=non_admin_user_tok,
)
self.helper.send_event(
room_id,
"m.room.message",
content={"body": "image-2", "msgtype": "m.image", "url": mxc_2},
txn_id="222",
tok=non_admin_user_tok,
)
# Quarantine all media in the room
url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote(
room_id
)
request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
self.render(request)
self.pump(1.0)
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
self.assertEqual(
json.loads(channel.result["body"].decode("utf-8")),
{"num_quarantined": 2},
"Expected 2 quarantined items",
)
# Convert mxc URLs to server/media_id strings
server_and_media_id_1 = mxc_1[6:]
server_and_media_id_2 = mxc_2[6:]
# Test that we cannot download any of the media anymore
request, channel = self.make_request(
"GET",
server_and_media_id_1,
shorthand=False,
access_token=non_admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id_1
),
)
request, channel = self.make_request(
"GET",
server_and_media_id_2,
shorthand=False,
access_token=non_admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id_2
),
)
def test_quarantine_all_media_by_user(self):
self.register_user("user_admin", "pass", admin=True)
admin_user_tok = self.login("user_admin", "pass")
non_admin_user = self.register_user("user_nonadmin", "pass", admin=False)
non_admin_user_tok = self.login("user_nonadmin", "pass")
# Upload some media
response_1 = self.helper.upload_media(
self.upload_resource, self.image_data, tok=non_admin_user_tok
)
response_2 = self.helper.upload_media(
self.upload_resource, self.image_data, tok=non_admin_user_tok
)
# Extract media IDs
server_and_media_id_1 = response_1["content_uri"][6:]
server_and_media_id_2 = response_2["content_uri"][6:]
# Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
non_admin_user
)
request, channel = self.make_request(
"POST", url.encode("ascii"), access_token=admin_user_tok,
)
self.render(request)
self.pump(1.0)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
json.loads(channel.result["body"].decode("utf-8")),
{"num_quarantined": 2},
"Expected 2 quarantined items",
)
# Attempt to access each piece of media
request, channel = self.make_request(
"GET",
server_and_media_id_1,
shorthand=False,
access_token=non_admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id_1,
),
)
# Attempt to access each piece of media
request, channel = self.make_request(
"GET",
server_and_media_id_2,
shorthand=False,
access_token=non_admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id_2
),
)

View file

@ -21,6 +21,8 @@ import time
import attr
from twisted.web.resource import Resource
from synapse.api.constants import Membership
from tests.server import make_request, render
@ -160,3 +162,38 @@ class RestHelper(object):
)
return channel.json_body
def upload_media(
self,
resource: Resource,
image_data: bytes,
tok: str,
filename: str = "test.png",
expect_code: int = 200,
) -> dict:
"""Upload a piece of test media to the media repo
Args:
resource: The resource that will handle the upload request
image_data: The image data to upload
tok: The user token to use during the upload
filename: The filename of the media to be uploaded
expect_code: The return code to expect from attempting to upload the media
"""
image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
request, channel = make_request(
self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok
)
request.requestHeaders.addRawHeader(
b"Content-Length", str(image_length).encode("UTF-8")
)
request.render(resource)
self.hs.get_reactor().pump([100])
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
int(channel.result["code"]),
channel.result["body"],
)
return channel.json_body

View file

@ -285,7 +285,9 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
)
# Make sure the invite is here.
pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
pending_invites = self.get_success(
store.get_invited_rooms_for_local_user(invitee_id)
)
self.assertEqual(len(pending_invites), 1, pending_invites)
self.assertEqual(pending_invites[0].room_id, room_id, pending_invites)
@ -293,12 +295,16 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.deactivate(invitee_id, invitee_tok)
# Check that the invite isn't there anymore.
pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
pending_invites = self.get_success(
store.get_invited_rooms_for_local_user(invitee_id)
)
self.assertEqual(len(pending_invites), 0, pending_invites)
# Check that the membership of @invitee:test in the room is now "leave".
memberships = self.get_success(
store.get_rooms_for_user_where_membership_is(invitee_id, [Membership.LEAVE])
store.get_rooms_for_local_user_where_membership_is(
invitee_id, [Membership.LEAVE]
)
)
self.assertEqual(len(memberships), 1, memberships)
self.assertEqual(memberships[0].room_id, room_id, memberships)

View file

@ -15,8 +15,6 @@
# limitations under the License.
import json
from mock import Mock
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes
from synapse.rest.client.v1 import login, room
@ -36,13 +34,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
]
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
"red", http_client=None, federation_client=Mock()
)
return hs
def test_sync_argless(self):
request, channel = self.make_request("GET", "/sync")
self.render(request)

View file

@ -57,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
rooms_for_user = self.get_success(
self.store.get_rooms_for_user_where_membership_is(
self.store.get_rooms_for_local_user_where_membership_is(
self.u_alice, [Membership.JOIN]
)
)

View file

@ -23,8 +23,12 @@ from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource
from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.http.server import (
DirectServeResource,
JsonResource,
wrap_html_request_handler,
)
from synapse.http.site import SynapseSite, logger
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock
@ -164,6 +168,77 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
class WrapHtmlRequestHandlerTests(unittest.TestCase):
class TestResource(DirectServeResource):
callback = None
@wrap_html_request_handler
async def _async_render_GET(self, request):
return await self.callback(request)
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
def test_good_response(self):
def callback(request):
request.write(b"response")
request.finish()
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path")
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200")
body = channel.result["body"]
self.assertEqual(body, b"response")
def test_redirect_exception(self):
"""
If the callback raises a RedirectException, it is turned into a 30x
with the right location.
"""
def callback(request, **kwargs):
raise RedirectException(b"/look/an/eagle", 301)
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path")
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"301")
headers = channel.result["headers"]
location_headers = [v for k, v in headers if k == b"Location"]
self.assertEqual(location_headers, [b"/look/an/eagle"])
def test_redirect_exception_with_cookie(self):
"""
If the callback raises a RedirectException which sets a cookie, that is
returned too
"""
def callback(request, **kwargs):
e = RedirectException(b"/no/over/there", 304)
e.cookies.append(b"session=yespls")
raise e
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path")
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"304")
headers = channel.result["headers"]
location_headers = [v for k, v in headers if k == b"Location"]
self.assertEqual(location_headers, [b"/no/over/there"])
cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
self.assertEqual(cookies_headers, [b"session=yespls"])
class SiteTestCase(unittest.HomeserverTestCase):
def test_lose_connection(self):
"""

View file

@ -181,6 +181,7 @@ commands = mypy \
synapse/handlers/ui_auth \
synapse/logging/ \
synapse/module_api \
synapse/replication \
synapse/rest/consent \
synapse/rest/saml2 \
synapse/spam_checker_api \