Merge branch 'develop' into matthew/dinsic_3pid_check

This commit is contained in:
Matthew Hodgson 2018-03-14 21:56:58 +00:00
commit e3eb2cfe8b
130 changed files with 5334 additions and 3312 deletions

2
.gitignore vendored
View file

@ -46,3 +46,5 @@ static/client/register/register_config.js
env/ env/
*.config *.config
.vscode/

View file

@ -0,0 +1,23 @@
# List all media in a room
This API gets a list of known media in a room.
The API is:
```
GET /_matrix/client/r0/admin/room/<room_id>/media
```
including an `access_token` of a server admin.
It returns a JSON body like the following:
```
{
"local": [
"mxc://localhost/xwvutsrqponmlkjihgfedcba",
"mxc://localhost/abcdefghijklmnopqrstuvwx"
],
"remote": [
"mxc://matrix.org/xwvutsrqponmlkjihgfedcba",
"mxc://matrix.org/abcdefghijklmnopqrstuvwx"
]
}
```

View file

@ -4,14 +4,58 @@ Purge History API
The purge history API allows server admins to purge historic events from their The purge history API allows server admins to purge historic events from their
database, reclaiming disk space. database, reclaiming disk space.
**NB!** This will not delete local events (locally sent messages content etc) from the database, but will remove lots of the metadata about them and does dramatically reduce the on disk space usage
Depending on the amount of history being purged a call to the API may take Depending on the amount of history being purged a call to the API may take
several minutes or longer. During this period users will not be able to several minutes or longer. During this period users will not be able to
paginate further back in the room from the point being purged from. paginate further back in the room from the point being purged from.
The API is simply: The API is:
``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>`` ``POST /_matrix/client/r0/admin/purge_history/<room_id>[/<event_id>]``
including an ``access_token`` of a server admin. including an ``access_token`` of a server admin.
By default, events sent by local users are not deleted, as they may represent
the only copies of this content in existence. (Events sent by remote users are
deleted, and room state data before the cutoff is always removed).
To delete local events as well, set ``delete_local_events`` in the body:
.. code:: json
{
"delete_local_events": true
}
The caller must specify the point in the room to purge up to. This can be
specified by including an event_id in the URI, or by setting a
``purge_up_to_event_id`` or ``purge_up_to_ts`` in the request body. If an event
id is given, that event (and others at the same graph depth) will be retained.
If ``purge_up_to_ts`` is given, it should be a timestamp since the unix epoch,
in milliseconds.
The API starts the purge running, and returns immediately with a JSON body with
a purge id:
.. code:: json
{
"purge_id": "<opaque id>"
}
Purge status query
------------------
It is possible to poll for updates on recent purges with a second API;
``GET /_matrix/client/r0/admin/purge_history_status/<purge_id>``
(again, with a suitable ``access_token``). This API returns a JSON body like
the following:
.. code:: json
{
"status": "active"
}
The status will be one of ``active``, ``complete``, or ``failed``.

View file

@ -279,9 +279,9 @@ Obviously that option means that the operations done in
that might be fixed by setting a different logcontext via a ``with that might be fixed by setting a different logcontext via a ``with
LoggingContext(...)`` in ``background_operation``). LoggingContext(...)`` in ``background_operation``).
The second option is to use ``logcontext.preserve_fn``, which wraps a function The second option is to use ``logcontext.run_in_background``, which wraps a
so that it doesn't reset the logcontext even when it returns an incomplete function so that it doesn't reset the logcontext even when it returns an
deferred, and adds a callback to the returned deferred to reset the incomplete deferred, and adds a callback to the returned deferred to reset the
logcontext. In other words, it turns a function that follows the Synapse rules logcontext. In other words, it turns a function that follows the Synapse rules
about logcontexts and Deferreds into one which behaves more like an external about logcontexts and Deferreds into one which behaves more like an external
function — the opposite operation to that described in the previous section. function — the opposite operation to that described in the previous section.
@ -293,7 +293,7 @@ It can be used like this:
def do_request_handling(): def do_request_handling():
yield foreground_operation() yield foreground_operation()
logcontext.preserve_fn(background_operation)() logcontext.run_in_background(background_operation)
# this will now be logged against the request context # this will now be logged against the request context
logger.debug("Request handling complete") logger.debug("Request handling complete")

View file

@ -30,17 +30,29 @@ requests made to the federation port. The caveats regarding running a
reverse-proxy on the federation port still apply (see reverse-proxy on the federation port still apply (see
https://github.com/matrix-org/synapse/blob/master/README.rst#reverse-proxying-the-federation-port). https://github.com/matrix-org/synapse/blob/master/README.rst#reverse-proxying-the-federation-port).
To enable workers, you need to add a replication listener to the master synapse, e.g.:: To enable workers, you need to add two replication listeners to the master
synapse, e.g.::
listeners: listeners:
# The TCP replication port
- port: 9092 - port: 9092
bind_address: '127.0.0.1' bind_address: '127.0.0.1'
type: replication type: replication
# The HTTP replication port
- port: 9093
bind_address: '127.0.0.1'
type: http
resources:
- names: [replication]
Under **no circumstances** should this replication API listener be exposed to the Under **no circumstances** should these replication API listeners be exposed to
public internet; it currently implements no authentication whatsoever and is the public internet; it currently implements no authentication whatsoever and is
unencrypted. unencrypted.
(Roughly, the TCP port is used for streaming data from the master to the
workers, and the HTTP port for the workers to send data to the main
synapse process.)
You then create a set of configs for the various worker processes. These You then create a set of configs for the various worker processes. These
should be worker configuration files, and should be stored in a dedicated should be worker configuration files, and should be stored in a dedicated
subdirectory, to allow synctl to manipulate them. subdirectory, to allow synctl to manipulate them.
@ -52,8 +64,13 @@ You should minimise the number of overrides though to maintain a usable config.
You must specify the type of worker application (``worker_app``). The currently You must specify the type of worker application (``worker_app``). The currently
available worker applications are listed below. You must also specify the available worker applications are listed below. You must also specify the
replication endpoint that it's talking to on the main synapse process replication endpoints that it's talking to on the main synapse process.
(``worker_replication_host`` and ``worker_replication_port``). ``worker_replication_host`` should specify the host of the main synapse,
``worker_replication_port`` should point to the TCP replication listener port and
``worker_replication_http_port`` should point to the HTTP replication port.
Currently, only the ``event_creator`` worker requires specifying
``worker_replication_http_port``.
For instance:: For instance::
@ -62,6 +79,7 @@ For instance::
# The replication listener on the synapse to talk to. # The replication listener on the synapse to talk to.
worker_replication_host: 127.0.0.1 worker_replication_host: 127.0.0.1
worker_replication_port: 9092 worker_replication_port: 9092
worker_replication_http_port: 9093
worker_listeners: worker_listeners:
- type: http - type: http
@ -207,3 +225,14 @@ the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration
file. For example:: file. For example::
worker_main_http_uri: http://127.0.0.1:8008 worker_main_http_uri: http://127.0.0.1:8008
``synapse.app.event_creator``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Handles non-state event creation. It can handle REST endpoints matching:
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
It will create events locally and then send them on to the main synapse
instance to be persisted and handled.

View file

@ -0,0 +1,133 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Moves a list of remote media from one media store to another.
The input should be a list of media files to be moved, one per line. Each line
should be formatted::
<origin server>|<file id>
This can be extracted from postgres with::
psql --tuples-only -A -c "select media_origin, filesystem_id from
matrix.remote_media_cache where ..."
To use, pipe the above into::
PYTHON_PATH=. ./scripts/move_remote_media_to_new_store.py <source repo> <dest repo>
"""
from __future__ import print_function
import argparse
import logging
import sys
import os
import shutil
from synapse.rest.media.v1.filepath import MediaFilePaths
logger = logging.getLogger()
def main(src_repo, dest_repo):
src_paths = MediaFilePaths(src_repo)
dest_paths = MediaFilePaths(dest_repo)
for line in sys.stdin:
line = line.strip()
parts = line.split('|')
if len(parts) != 2:
print("Unable to parse input line %s" % line, file=sys.stderr)
exit(1)
move_media(parts[0], parts[1], src_paths, dest_paths)
def move_media(origin_server, file_id, src_paths, dest_paths):
"""Move the given file, and any thumbnails, to the dest repo
Args:
origin_server (str):
file_id (str):
src_paths (MediaFilePaths):
dest_paths (MediaFilePaths):
"""
logger.info("%s/%s", origin_server, file_id)
# check that the original exists
original_file = src_paths.remote_media_filepath(origin_server, file_id)
if not os.path.exists(original_file):
logger.warn(
"Original for %s/%s (%s) does not exist",
origin_server, file_id, original_file,
)
else:
mkdir_and_move(
original_file,
dest_paths.remote_media_filepath(origin_server, file_id),
)
# now look for thumbnails
original_thumb_dir = src_paths.remote_media_thumbnail_dir(
origin_server, file_id,
)
if not os.path.exists(original_thumb_dir):
return
mkdir_and_move(
original_thumb_dir,
dest_paths.remote_media_thumbnail_dir(origin_server, file_id)
)
def mkdir_and_move(original_file, dest_file):
dirname = os.path.dirname(dest_file)
if not os.path.exists(dirname):
logger.debug("mkdir %s", dirname)
os.makedirs(dirname)
logger.debug("mv %s %s", original_file, dest_file)
shutil.move(original_file, dest_file)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class = argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"-v", action='store_true', help='enable debug logging')
parser.add_argument(
"src_repo",
help="Path to source content repo",
)
parser.add_argument(
"dest_repo",
help="Path to source content repo",
)
args = parser.parse_args()
logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
}
logging.basicConfig(**logging_config)
main(args.src_repo, args.dest_repo)

View file

@ -49,19 +49,6 @@ class AppserviceSlaveStore(
class AppserviceServer(HomeServer): class AppserviceServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self) self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)

View file

@ -64,19 +64,6 @@ class ClientReaderSlavedStore(
class ClientReaderServer(HomeServer): class ClientReaderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self) self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
@ -169,7 +156,6 @@ def start(config_options):
) )
ss.setup() ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners) ss.start_listening(config.worker_listeners)
def start(): def start():

View file

@ -0,0 +1,177 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
import synapse
from synapse import events
from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.pushers import SlavedPusherStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.room import RoomSendEventRestServlet
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor
from twisted.web.resource import Resource
logger = logging.getLogger("synapse.app.event_creator")
class EventCreatorSlavedStore(
SlavedAccountDataStore,
SlavedPusherStore,
SlavedReceiptsStore,
SlavedPushRuleStore,
SlavedDeviceStore,
SlavedClientIpStore,
SlavedApplicationServiceStore,
SlavedEventStore,
SlavedRegistrationStore,
RoomStore,
BaseSlavedStore,
):
pass
class EventCreatorServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = EventCreatorSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "client":
resource = JsonResource(self, canonical_json=False)
RoomSendEventRestServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
})
root_resource = create_resource_tree(resources, Resource())
_base.listen_tcp(
bind_addresses,
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
)
)
logger.info("Synapse event creator now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
_base.listen_tcp(
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
def build_tcp_replication(self):
return ReplicationClientHandler(self.get_datastore())
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse event creator", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.event_creator"
assert config.worker_replication_http_port is not None
setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = EventCreatorServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.start_listening(config.worker_listeners)
def start():
ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start)
_base.start_worker_reactor("synapse-event-creator", config)
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -58,19 +58,6 @@ class FederationReaderSlavedStore(
class FederationReaderServer(HomeServer): class FederationReaderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self) self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
@ -157,7 +144,6 @@ def start(config_options):
) )
ss.setup() ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners) ss.start_listening(config.worker_listeners)
def start(): def start():

View file

@ -76,19 +76,6 @@ class FederationSenderSlaveStore(
class FederationSenderServer(HomeServer): class FederationSenderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self) self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)

View file

@ -118,19 +118,6 @@ class FrontendProxySlavedStore(
class FrontendProxyServer(HomeServer): class FrontendProxyServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self) self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
@ -224,7 +211,6 @@ def start(config_options):
) )
ss.setup() ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners) ss.start_listening(config.worker_listeners)
def start(): def start():

View file

@ -38,6 +38,7 @@ from synapse.metrics import register_memory_metrics
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, \ from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, \
check_requirements check_requirements
from synapse.replication.http import ReplicationRestResource, REPLICATION_PREFIX
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest import ClientRestResource from synapse.rest import ClientRestResource
from synapse.rest.key.v1.server_key_resource import LocalKey from synapse.rest.key.v1.server_key_resource import LocalKey
@ -219,6 +220,9 @@ class SynapseHomeServer(HomeServer):
if name == "metrics" and self.get_config().enable_metrics: if name == "metrics" and self.get_config().enable_metrics:
resources[METRICS_PREFIX] = MetricsResource(self) resources[METRICS_PREFIX] = MetricsResource(self)
if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
return resources return resources
def start_listening(self): def start_listening(self):
@ -266,19 +270,6 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e: except IncorrectDatabaseSetup as e:
quit_with_error(e.message) quit_with_error(e.message)
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(config_options): def setup(config_options):
""" """
@ -357,7 +348,7 @@ def setup(config_options):
hs.get_state_handler().start_caching() hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling() hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates() hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache() hs.get_federation_client().start_get_pdu_cache()
register_memory_metrics(hs) register_memory_metrics(hs)

View file

@ -60,19 +60,6 @@ class MediaRepositorySlavedStore(
class MediaRepositoryServer(HomeServer): class MediaRepositoryServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self) self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
@ -171,7 +158,6 @@ def start(config_options):
) )
ss.setup() ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners) ss.start_listening(config.worker_listeners)
def start(): def start():

View file

@ -32,7 +32,6 @@ from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.roommember import RoomMemberStore
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, preserve_fn from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
@ -75,25 +74,8 @@ class PusherSlaveStore(
DataStore.get_profile_displayname.__func__ DataStore.get_profile_displayname.__func__
) )
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
class PusherServer(HomeServer): class PusherServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = PusherSlaveStore(self.get_db_conn(), self) self.datastore = PusherSlaveStore(self.get_db_conn(), self)

View file

@ -62,8 +62,6 @@ logger = logging.getLogger("synapse.app.synchrotron")
class SynchrotronSlavedStore( class SynchrotronSlavedStore(
SlavedPushRuleStore,
SlavedEventStore,
SlavedReceiptsStore, SlavedReceiptsStore,
SlavedAccountDataStore, SlavedAccountDataStore,
SlavedApplicationServiceStore, SlavedApplicationServiceStore,
@ -73,14 +71,12 @@ class SynchrotronSlavedStore(
SlavedGroupServerStore, SlavedGroupServerStore,
SlavedDeviceInboxStore, SlavedDeviceInboxStore,
SlavedDeviceStore, SlavedDeviceStore,
SlavedPushRuleStore,
SlavedEventStore,
SlavedClientIpStore, SlavedClientIpStore,
RoomStore, RoomStore,
BaseSlavedStore, BaseSlavedStore,
): ):
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
did_forget = ( did_forget = (
RoomMemberStore.__dict__["did_forget"] RoomMemberStore.__dict__["did_forget"]
) )
@ -246,19 +242,6 @@ class SynchrotronApplicationService(object):
class SynchrotronServer(HomeServer): class SynchrotronServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self) self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)

View file

@ -92,19 +92,6 @@ class UserDirectorySlaveStore(
class UserDirectoryServer(HomeServer): class UserDirectoryServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self): def setup(self):
logger.info("Setting up.") logger.info("Setting up.")
self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self) self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)

View file

@ -33,8 +33,16 @@ class WorkerConfig(Config):
self.worker_pid_file = config.get("worker_pid_file") self.worker_pid_file = config.get("worker_pid_file")
self.worker_log_file = config.get("worker_log_file") self.worker_log_file = config.get("worker_log_file")
self.worker_log_config = config.get("worker_log_config") self.worker_log_config = config.get("worker_log_config")
# The host used to connect to the main synapse
self.worker_replication_host = config.get("worker_replication_host", None) self.worker_replication_host = config.get("worker_replication_host", None)
# The port on the main synapse for TCP replication
self.worker_replication_port = config.get("worker_replication_port", None) self.worker_replication_port = config.get("worker_replication_port", None)
# The port on the main synapse for HTTP replication endpoint
self.worker_replication_http_port = config.get("worker_replication_http_port")
self.worker_name = config.get("worker_name", self.worker_app) self.worker_name = config.get("worker_name", self.worker_app)
self.worker_main_http_uri = config.get("worker_main_http_uri", None) self.worker_main_http_uri = config.get("worker_main_http_uri", None)

View file

@ -13,6 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from frozendict import frozendict
class EventContext(object): class EventContext(object):
""" """
@ -25,7 +29,9 @@ class EventContext(object):
The current state map excluding the current event. The current state map excluding the current event.
(type, state_key) -> event_id (type, state_key) -> event_id
state_group (int): state group id state_group (int|None): state group id, if the state has been stored
as a state group. This is usually only None if e.g. the event is
an outlier.
rejected (bool|str): A rejection reason if the event was rejected, else rejected (bool|str): A rejection reason if the event was rejected, else
False False
@ -46,7 +52,6 @@ class EventContext(object):
"prev_state_ids", "prev_state_ids",
"state_group", "state_group",
"rejected", "rejected",
"push_actions",
"prev_group", "prev_group",
"delta_ids", "delta_ids",
"prev_state_events", "prev_state_events",
@ -61,7 +66,6 @@ class EventContext(object):
self.state_group = None self.state_group = None
self.rejected = False self.rejected = False
self.push_actions = []
# A previously persisted state group and a delta between that # A previously persisted state group and a delta between that
# and this state. # and this state.
@ -71,3 +75,98 @@ class EventContext(object):
self.prev_state_events = None self.prev_state_events = None
self.app_service = None self.app_service = None
def serialize(self, event):
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
Args:
event (FrozenEvent): The event that this context relates to
Returns:
dict
"""
# We don't serialize the full state dicts, instead they get pulled out
# of the DB on the other side. However, the other side can't figure out
# the prev_state_ids, so if we're a state event we include the event
# id that we replaced in the state.
if event.is_state():
prev_state_id = self.prev_state_ids.get((event.type, event.state_key))
else:
prev_state_id = None
return {
"prev_state_id": prev_state_id,
"event_type": event.type,
"event_state_key": event.state_key if event.is_state() else None,
"state_group": self.state_group,
"rejected": self.rejected,
"prev_group": self.prev_group,
"delta_ids": _encode_state_dict(self.delta_ids),
"prev_state_events": self.prev_state_events,
"app_service_id": self.app_service.id if self.app_service else None
}
@staticmethod
@defer.inlineCallbacks
def deserialize(store, input):
"""Converts a dict that was produced by `serialize` back into a
EventContext.
Args:
store (DataStore): Used to convert AS ID to AS object
input (dict): A dict produced by `serialize`
Returns:
EventContext
"""
context = EventContext()
context.state_group = input["state_group"]
context.rejected = input["rejected"]
context.prev_group = input["prev_group"]
context.delta_ids = _decode_state_dict(input["delta_ids"])
context.prev_state_events = input["prev_state_events"]
# We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids.
prev_state_id = input["prev_state_id"]
event_type = input["event_type"]
event_state_key = input["event_state_key"]
context.current_state_ids = yield store.get_state_ids_for_group(
context.state_group,
)
if prev_state_id and event_state_key:
context.prev_state_ids = dict(context.current_state_ids)
context.prev_state_ids[(event_type, event_state_key)] = prev_state_id
else:
context.prev_state_ids = context.current_state_ids
app_service_id = input["app_service_id"]
if app_service_id:
context.app_service = store.get_app_service_by_id(app_service_id)
defer.returnValue(context)
def _encode_state_dict(state_dict):
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
JSON we need to convert them to a form that can.
"""
if state_dict is None:
return None
return [
(etype, state_key, v)
for (etype, state_key), v in state_dict.iteritems()
]
def _decode_state_dict(input):
"""Decodes a state dict encoded using `_encode_state_dict` above
"""
if input is None:
return None
return frozendict({(etype, state_key,): v for etype, state_key, v in input})

View file

@ -15,11 +15,3 @@
""" This package includes all the federation specific logic. """ This package includes all the federation specific logic.
""" """
from .replication import ReplicationLayer
def initialize_http_replication(hs):
transport = hs.get_federation_transport_client()
return ReplicationLayer(hs, transport)

View file

@ -27,7 +27,13 @@ logger = logging.getLogger(__name__)
class FederationBase(object): class FederationBase(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.store = hs.get_datastore()
self._clock = hs.get_clock()
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False, def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,

View file

@ -58,6 +58,7 @@ class FederationClient(FederationBase):
self._clear_tried_cache, 60 * 1000, self._clear_tried_cache, 60 * 1000,
) )
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
def _clear_tried_cache(self): def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache""" """Clear pdu_destination_tried cache"""

View file

@ -17,12 +17,14 @@ import logging
import simplejson as json import simplejson as json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, FederationError, SynapseError from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
from synapse.crypto.event_signing import compute_event_signature from synapse.crypto.event_signing import compute_event_signature
from synapse.federation.federation_base import ( from synapse.federation.federation_base import (
FederationBase, FederationBase,
event_from_pdu_json, event_from_pdu_json,
) )
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction from synapse.federation.units import Edu, Transaction
import synapse.metrics import synapse.metrics
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
@ -52,50 +54,19 @@ class FederationServer(FederationBase):
super(FederationServer, self).__init__(hs) super(FederationServer, self).__init__(hs)
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.handler = hs.get_handlers().federation_handler
self._server_linearizer = async.Linearizer("fed_server") self._server_linearizer = async.Linearizer("fed_server")
self._transaction_linearizer = async.Linearizer("fed_txn_handler") self._transaction_linearizer = async.Linearizer("fed_txn_handler")
self.transaction_actions = TransactionActions(self.store)
self.registry = hs.get_federation_registry()
# We cache responses to state queries, as they take a while and often # We cache responses to state queries, as they take a while and often
# come in waves. # come in waves.
self._state_resp_cache = ResponseCache(hs, timeout_ms=30000) self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are
documented on :py:class:`.ReplicationHandler`.
"""
self.handler = handler
def register_edu_handler(self, edu_type, handler):
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation Query of the given type.
Args:
query_type (str): Category name of the query, which should match
the string used by make_query.
handler (callable): Invoked to handle incoming queries of this type
handler is invoked as:
result = handler(args)
where 'args' is a dict mapping strings to strings of the query
arguments. It should return a Deferred that will eventually yield an
object to encode as JSON.
"""
if query_type in self.query_handlers:
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
self.query_handlers[query_type] = handler
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_backfill_request(self, origin, room_id, versions, limit): def on_backfill_request(self, origin, room_id, versions, limit):
@ -229,16 +200,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def received_edu(self, origin, edu_type, content): def received_edu(self, origin, edu_type, content):
received_edus_counter.inc() received_edus_counter.inc()
yield self.registry.on_edu(edu_type, origin, content)
if edu_type in self.edu_handlers:
try:
yield self.edu_handlers[edu_type](origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception as e:
logger.exception("Failed to handle edu %r", edu_type)
else:
logger.warn("Received EDU of type %s with no handler", edu_type)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -328,14 +290,8 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_query_request(self, query_type, args): def on_query_request(self, query_type, args):
received_queries_counter.inc(query_type) received_queries_counter.inc(query_type)
resp = yield self.registry.on_query(query_type, args)
if query_type in self.query_handlers: defer.returnValue((200, resp))
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
else:
defer.returnValue(
(404, "No handler for Query type '%s'" % (query_type,))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_make_join_request(self, room_id, user_id): def on_make_join_request(self, room_id, user_id):
@ -607,3 +563,66 @@ class FederationServer(FederationBase):
origin, room_id, event_dict origin, room_id, event_dict
) )
defer.returnValue(ret) defer.returnValue(ret)
class FederationHandlerRegistry(object):
"""Allows classes to register themselves as handlers for a given EDU or
query type for incoming federation traffic.
"""
def __init__(self):
self.edu_handlers = {}
self.query_handlers = {}
def register_edu_handler(self, edu_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation EDU of the given type.
Args:
edu_type (str): The type of the incoming EDU to register handler for
handler (Callable[[str, dict]]): A callable invoked on incoming EDU
of the given type. The arguments are the origin server name and
the EDU contents.
"""
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
Args:
query_type (str): Category name of the query, which should match
the string used by make_query.
handler (Callable[[dict], Deferred[dict]]): Invoked to handle
incoming queries of this type. The return will be yielded
on and the result used as the response to the query request.
"""
if query_type in self.query_handlers:
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
self.query_handlers[query_type] = handler
@defer.inlineCallbacks
def on_edu(self, edu_type, origin, content):
handler = self.edu_handlers.get(edu_type)
if not handler:
logger.warn("No handler registered for EDU type %s", edu_type)
try:
yield handler(origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception as e:
logger.exception("Failed to handle edu %r", edu_type)
def on_query(self, query_type, args):
handler = self.query_handlers.get(query_type)
if not handler:
logger.warn("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
return handler(args)

View file

@ -1,73 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This layer is responsible for replicating with remote home servers using
a given transport.
"""
from .federation_client import FederationClient
from .federation_server import FederationServer
from .persistence import TransactionActions
import logging
logger = logging.getLogger(__name__)
class ReplicationLayer(FederationClient, FederationServer):
"""This layer is responsible for replicating with remote home servers over
the given transport. I.e., does the sending and receiving of PDUs to
remote home servers.
The layer communicates with the rest of the server via a registered
ReplicationHandler.
In more detail, the layer:
* Receives incoming data and processes it into transactions and pdus.
* Fetches any PDUs it thinks it might have missed.
* Keeps the current state for contexts up to date by applying the
suitable conflict resolution.
* Sends outgoing pdus wrapped in transactions.
* Fills out the references to previous pdus/transactions appropriately
for outgoing data.
"""
def __init__(self, hs, transport_layer):
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.transport_layer = transport_layer
self.federation_client = self
self.store = hs.get_datastore()
self.handler = None
self.edu_handlers = {}
self.query_handlers = {}
self._clock = hs.get_clock()
self.transaction_actions = TransactionActions(self.store)
self.hs = hs
super(ReplicationLayer, self).__init__(hs)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name

View file

@ -1190,7 +1190,7 @@ GROUP_ATTESTATION_SERVLET_CLASSES = (
def register_servlets(hs, resource, authenticator, ratelimiter): def register_servlets(hs, resource, authenticator, ratelimiter):
for servletclass in FEDERATION_SERVLET_CLASSES: for servletclass in FEDERATION_SERVLET_CLASSES:
servletclass( servletclass(
handler=hs.get_replication_layer(), handler=hs.get_federation_server(),
authenticator=authenticator, authenticator=authenticator,
ratelimiter=ratelimiter, ratelimiter=ratelimiter,
server_name=hs.hostname, server_name=hs.hostname,

View file

@ -17,7 +17,6 @@ from .register import RegistrationHandler
from .room import ( from .room import (
RoomCreationHandler, RoomContextHandler, RoomCreationHandler, RoomContextHandler,
) )
from .room_member import RoomMemberHandler
from .message import MessageHandler from .message import MessageHandler
from .federation import FederationHandler from .federation import FederationHandler
from .directory import DirectoryHandler from .directory import DirectoryHandler
@ -49,7 +48,6 @@ class Handlers(object):
self.registration_handler = RegistrationHandler(hs) self.registration_handler = RegistrationHandler(hs)
self.message_handler = MessageHandler(hs) self.message_handler = MessageHandler(hs)
self.room_creation_handler = RoomCreationHandler(hs) self.room_creation_handler = RoomCreationHandler(hs)
self.room_member_handler = RoomMemberHandler(hs)
self.federation_handler = FederationHandler(hs) self.federation_handler = FederationHandler(hs)
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)

View file

@ -158,7 +158,7 @@ class BaseHandler(object):
# homeserver. # homeserver.
requester = synapse.types.create_requester( requester = synapse.types.create_requester(
target_user, is_guest=True) target_user, is_guest=True)
handler = self.hs.get_handlers().room_member_handler handler = self.hs.get_room_member_handler()
yield handler.update_membership( yield handler.update_membership(
requester, requester,
target_user, target_user,

View file

@ -863,8 +863,10 @@ class AuthHandler(BaseHandler):
""" """
def _do_validate_hash(): def _do_validate_hash():
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, return bcrypt.checkpw(
stored_hash.encode('utf8')) == stored_hash password.encode('utf8') + self.hs.config.password_pepper,
stored_hash.encode('utf8')
)
if stored_hash: if stored_hash:
return make_deferred_yieldable(threads.deferToThread(_do_validate_hash)) return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))

View file

@ -37,14 +37,15 @@ class DeviceHandler(BaseHandler):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self.federation_sender = hs.get_federation_sender() self.federation_sender = hs.get_federation_sender()
self.federation = hs.get_replication_layer()
self._edu_updater = DeviceListEduUpdater(hs, self) self._edu_updater = DeviceListEduUpdater(hs, self)
self.federation.register_edu_handler( federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
"m.device_list_update", self._edu_updater.incoming_device_list_update, "m.device_list_update", self._edu_updater.incoming_device_list_update,
) )
self.federation.register_query_handler( federation_registry.register_query_handler(
"user_devices", self.on_federation_query_user_devices, "user_devices", self.on_federation_query_user_devices,
) )
@ -430,7 +431,7 @@ class DeviceListEduUpdater(object):
def __init__(self, hs, device_handler): def __init__(self, hs, device_handler):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.federation = hs.get_replication_layer() self.federation = hs.get_federation_client()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.device_handler = device_handler self.device_handler = device_handler

View file

@ -37,7 +37,7 @@ class DeviceMessageHandler(object):
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler( hs.get_federation_registry().register_edu_handler(
"m.direct_to_device", self.on_direct_to_device_edu "m.direct_to_device", self.on_direct_to_device_edu
) )

View file

@ -34,9 +34,10 @@ class DirectoryHandler(BaseHandler):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.federation = hs.get_replication_layer() self.federation = hs.get_federation_client()
self.federation.register_query_handler( hs.get_federation_registry().register_query_handler(
"directory", self.on_directory_query "directory", self.on_directory_query
) )
@ -249,8 +250,7 @@ class DirectoryHandler(BaseHandler):
def send_room_alias_update_event(self, requester, user_id, room_id): def send_room_alias_update_event(self, requester, user_id, room_id):
aliases = yield self.store.get_aliases_for_room(room_id) aliases = yield self.store.get_aliases_for_room(room_id)
msg_handler = self.hs.get_handlers().message_handler yield self.event_creation_handler.create_and_send_nonmember_event(
yield msg_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Aliases, "type": EventTypes.Aliases,
@ -272,8 +272,7 @@ class DirectoryHandler(BaseHandler):
if not alias_event or alias_event.content.get("alias", "") != alias_str: if not alias_event or alias_event.content.get("alias", "") != alias_str:
return return
msg_handler = self.hs.get_handlers().message_handler yield self.event_creation_handler.create_and_send_nonmember_event(
yield msg_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.CanonicalAlias, "type": EventTypes.CanonicalAlias,

View file

@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
class E2eKeysHandler(object): class E2eKeysHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.federation = hs.get_replication_layer() self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -40,7 +40,7 @@ class E2eKeysHandler(object):
# doesn't really work as part of the generic query API, because the # doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the # query request requires an object POST, but we abuse the
# "query handler" interface. # "query handler" interface.
self.federation.register_query_handler( hs.get_federation_registry().register_query_handler(
"client_keys", self.on_federation_query_client_keys "client_keys", self.on_federation_query_client_keys
) )

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -67,7 +68,7 @@ class FederationHandler(BaseHandler):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.replication_layer = hs.get_replication_layer() self.replication_layer = hs.get_federation_client()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
@ -75,8 +76,7 @@ class FederationHandler(BaseHandler):
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.pusher_pool = hs.get_pusherpool() self.pusher_pool = hs.get_pusherpool()
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self.replication_layer.set_handler(self)
# When joining a room we need to queue any events for that room up # When joining a room we need to queue any events for that room up
self.room_queues = {} self.room_queues = {}
@ -808,13 +808,12 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys()) event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill") logger.debug("calling resolve_state_groups in _maybe_backfill")
states = yield logcontext.make_deferred_yieldable(defer.gatherResults( resolve = logcontext.preserve_fn(
[ self.state_handler.resolve_state_groups_for_events
logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
room_id, [e]
) )
for e in event_ids states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
], consumeErrors=True, [resolve(room_id, [e]) for e in event_ids],
consumeErrors=True,
)) ))
states = dict(zip(event_ids, [s.state for s in states])) states = dict(zip(event_ids, [s.state for s in states]))
@ -1008,8 +1007,7 @@ class FederationHandler(BaseHandler):
}) })
try: try:
message_handler = self.hs.get_handlers().message_handler event, context = yield self.event_creation_handler.create_new_client_event(
event, context = yield message_handler._create_new_client_event(
builder=builder, builder=builder,
) )
except AuthError as e: except AuthError as e:
@ -1249,8 +1247,7 @@ class FederationHandler(BaseHandler):
"state_key": user_id, "state_key": user_id,
}) })
message_handler = self.hs.get_handlers().message_handler event, context = yield self.event_creation_handler.create_new_client_event(
event, context = yield message_handler._create_new_client_event(
builder=builder, builder=builder,
) )
@ -1448,6 +1445,7 @@ class FederationHandler(BaseHandler):
auth_events=auth_events, auth_events=auth_events,
) )
try:
if not event.internal_metadata.is_outlier() and not backfilled: if not event.internal_metadata.is_outlier() and not backfilled:
yield self.action_generator.handle_push_actions_for_event( yield self.action_generator.handle_push_actions_for_event(
event, context event, context
@ -1458,6 +1456,13 @@ class FederationHandler(BaseHandler):
context=context, context=context,
backfilled=backfilled, backfilled=backfilled,
) )
except: # noqa: E722, as we reraise the exception this is fine.
# Ensure that we actually remove the entries in the push actions
# staging area
logcontext.preserve_fn(
self.store.remove_push_actions_from_staging
)(event.event_id)
raise
if not backfilled: if not backfilled:
# this intentionally does not yield: we don't care about the result # this intentionally does not yield: we don't care about the result
@ -1832,8 +1837,8 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values()) current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state different_auth = event_auth_events - current_state
self._update_context_for_auth_events( yield self._update_context_for_auth_events(
context, auth_events, event_key, event, context, auth_events, event_key,
) )
if different_auth and not event.internal_metadata.is_outlier(): if different_auth and not event.internal_metadata.is_outlier():
@ -1914,8 +1919,8 @@ class FederationHandler(BaseHandler):
# 4. Look at rejects and their proofs. # 4. Look at rejects and their proofs.
# TODO. # TODO.
self._update_context_for_auth_events( yield self._update_context_for_auth_events(
context, auth_events, event_key, event, context, auth_events, event_key,
) )
try: try:
@ -1924,11 +1929,15 @@ class FederationHandler(BaseHandler):
logger.warn("Failed auth resolution for %r because %s", event, e) logger.warn("Failed auth resolution for %r because %s", event, e)
raise e raise e
def _update_context_for_auth_events(self, context, auth_events, @defer.inlineCallbacks
def _update_context_for_auth_events(self, event, context, auth_events,
event_key): event_key):
"""Update the state_ids in an event context after auth event resolution """Update the state_ids in an event context after auth event resolution,
storing the changes as a new state group.
Args: Args:
event (Event): The event we're handling the context for
context (synapse.events.snapshot.EventContext): event context context (synapse.events.snapshot.EventContext): event context
to be updated to be updated
@ -1951,7 +1960,13 @@ class FederationHandler(BaseHandler):
context.prev_state_ids.update({ context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.iteritems() k: a.event_id for k, a in auth_events.iteritems()
}) })
context.state_group = self.store.get_next_state_group() context.state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=context.prev_group,
delta_ids=context.delta_ids,
current_state_ids=context.current_state_ids,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def construct_auth_difference(self, local_auth, remote_auth): def construct_auth_difference(self, local_auth, remote_auth):
@ -2121,8 +2136,7 @@ class FederationHandler(BaseHandler):
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder) EventValidator().validate_new(builder)
message_handler = self.hs.get_handlers().message_handler event, context = yield self.event_creation_handler.create_new_client_event(
event, context = yield message_handler._create_new_client_event(
builder=builder builder=builder
) )
@ -2137,7 +2151,7 @@ class FederationHandler(BaseHandler):
raise e raise e
yield self._check_signature(event, context) yield self._check_signature(event, context)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context) yield member_handler.send_membership_event(None, event, context)
else: else:
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
@ -2160,8 +2174,7 @@ class FederationHandler(BaseHandler):
""" """
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
message_handler = self.hs.get_handlers().message_handler event, context = yield self.event_creation_handler.create_new_client_event(
event, context = yield message_handler._create_new_client_event(
builder=builder, builder=builder,
) )
@ -2182,7 +2195,7 @@ class FederationHandler(BaseHandler):
# TODO: Make sure the signatures actually are correct. # TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures) event.signatures.update(returned_invite.signatures)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context) yield member_handler.send_membership_event(None, event, context)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -2211,8 +2224,9 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder) EventValidator().validate_new(builder)
message_handler = self.hs.get_handlers().message_handler event, context = yield self.event_creation_handler.create_new_client_event(
event, context = yield message_handler._create_new_client_event(builder=builder) builder=builder,
)
defer.returnValue((event, context)) defer.returnValue((event, context))
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd # Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd # Copyright 2017 - 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,7 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer, reactor
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, SynapseError
@ -24,10 +25,12 @@ from synapse.types import (
UserID, RoomAlias, RoomStreamToken, UserID, RoomAlias, RoomStreamToken,
) )
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn, run_in_background
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from synapse.replication.http.send_event import send_event_to_master
from ._base import BaseHandler from ._base import BaseHandler
@ -40,6 +43,36 @@ import ujson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PurgeStatus(object):
"""Object tracking the status of a purge request
This class contains information on the progress of a purge request, for
return by get_purge_status.
Attributes:
status (int): Tracks whether this request has completed. One of
STATUS_{ACTIVE,COMPLETE,FAILED}
"""
STATUS_ACTIVE = 0
STATUS_COMPLETE = 1
STATUS_FAILED = 2
STATUS_TEXT = {
STATUS_ACTIVE: "active",
STATUS_COMPLETE: "complete",
STATUS_FAILED: "failed",
}
def __init__(self):
self.status = PurgeStatus.STATUS_ACTIVE
def asdict(self):
return {
"status": PurgeStatus.STATUS_TEXT[self.status]
}
class MessageHandler(BaseHandler): class MessageHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
@ -47,32 +80,89 @@ class MessageHandler(BaseHandler):
self.hs = hs self.hs = hs
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.validator = EventValidator()
self.profile_handler = hs.get_profile_handler()
self.pagination_lock = ReadWriteLock() self.pagination_lock = ReadWriteLock()
self._purges_in_progress_by_room = set()
# map from purge id to PurgeStatus
self._purges_by_id = {}
self.pusher_pool = hs.get_pusherpool() def start_purge_history(self, room_id, topological_ordering,
delete_local_events=False):
"""Start off a history purge on a room.
# We arbitrarily limit concurrent event creation for a room to 5. Args:
# This is to stop us from diverging history *too* much. room_id (str): The room to purge from
self.limiter = Limiter(max_count=5)
self.action_generator = hs.get_action_generator() topological_ordering (int): minimum topo ordering to preserve
delete_local_events (bool): True to delete local events as well as
remote ones
self.spam_checker = hs.get_spam_checker() Returns:
str: unique ID for this purge transaction.
"""
if room_id in self._purges_in_progress_by_room:
raise SynapseError(
400,
"History purge already in progress for %s" % (room_id, ),
)
purge_id = random_string(16)
# we log the purge_id here so that it can be tied back to the
# request id in the log lines.
logger.info("[purge] starting purge_id %s", purge_id)
self._purges_by_id[purge_id] = PurgeStatus()
run_in_background(
self._purge_history,
purge_id, room_id, topological_ordering, delete_local_events,
)
return purge_id
@defer.inlineCallbacks @defer.inlineCallbacks
def purge_history(self, room_id, event_id): def _purge_history(self, purge_id, room_id, topological_ordering,
event = yield self.store.get_event(event_id) delete_local_events):
"""Carry out a history purge on a room.
if event.room_id != room_id: Args:
raise SynapseError(400, "Event is for wrong room.") purge_id (str): The id for this purge
room_id (str): The room to purge from
depth = event.depth topological_ordering (int): minimum topo ordering to preserve
delete_local_events (bool): True to delete local events as well as
remote ones
Returns:
Deferred
"""
self._purges_in_progress_by_room.add(room_id)
try:
with (yield self.pagination_lock.write(room_id)): with (yield self.pagination_lock.write(room_id)):
yield self.store.delete_old_state(room_id, depth) yield self.store.purge_history(
room_id, topological_ordering, delete_local_events,
)
logger.info("[purge] complete")
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
except Exception:
logger.error("[purge] failed: %s", Failure().getTraceback().rstrip())
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
finally:
self._purges_in_progress_by_room.discard(room_id)
# remove the purge from the list 24 hours after it completes
def clear_purge():
del self._purges_by_id[purge_id]
reactor.callLater(24 * 3600, clear_purge)
def get_purge_status(self, purge_id):
"""Get the current status of an active purge
Args:
purge_id (str): purge_id returned by start_purge_history
Returns:
PurgeStatus|None
"""
return self._purges_by_id.get(purge_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None, def get_messages(self, requester, room_id=None, pagin_config=None,
@ -182,166 +272,6 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
@defer.inlineCallbacks
def create_event(self, requester, event_dict, token_id=None, txn_id=None,
prev_event_ids=None):
"""
Given a dict from a client, create a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
etc.
Adds display names to Join membership events.
Args:
requester
event_dict (dict): An entire event
token_id (str)
txn_id (str)
prev_event_ids (list): The prev event ids to use when creating the event
Returns:
Tuple of created event (FrozenEvent), Context
"""
builder = self.event_builder_factory.new(event_dict)
with (yield self.limiter.queue(builder.room_id)):
self.validator.validate_new(builder)
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
if membership in {Membership.JOIN, Membership.INVITE}:
# If event doesn't include a display name, add one.
profile = self.profile_handler
content = builder.content
try:
if "displayname" not in content:
content["displayname"] = yield profile.get_displayname(target)
if "avatar_url" not in content:
content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s",
target, e
)
if token_id is not None:
builder.internal_metadata.token_id = token_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
event, context = yield self._create_new_client_event(
builder=builder,
requester=requester,
prev_event_ids=prev_event_ids,
)
defer.returnValue((event, context))
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
"""
Persists and notifies local clients and federation of an event.
Args:
event (FrozenEvent) the event to send.
context (Context) the context of the event.
ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest.
"""
if event.type == EventTypes.Member:
raise SynapseError(
500,
"Tried to send member event through non-member codepath"
)
# We check here if we are currently being rate limited, so that we
# don't do unnecessary work. We check again just before we actually
# send the event.
yield self.ratelimit(requester, update=False)
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
prev_state = yield self.deduplicate_state_event(event, context)
if prev_state is not None:
defer.returnValue(prev_state)
yield self.handle_new_client_event(
requester=requester,
event=event,
context=context,
ratelimit=ratelimit,
)
if event.type == EventTypes.Message:
presence = self.hs.get_presence_handler()
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
preserve_fn(presence.bump_presence_active_time)(user)
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context):
"""
Checks whether event is in the latest resolved state in context.
If so, returns the version of the event in context.
Otherwise, returns None.
"""
prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
defer.returnValue(prev_event)
return
@defer.inlineCallbacks
def create_and_send_nonmember_event(
self,
requester,
event_dict,
ratelimit=True,
txn_id=None
):
"""
Creates an event, then sends it.
See self.create_event and self.send_nonmember_event.
"""
event, context = yield self.create_event(
requester,
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id
)
spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, basestring):
spam_error = "Spam is not permitted here"
raise SynapseError(
403, spam_error, Codes.FORBIDDEN
)
yield self.send_nonmember_event(
requester,
event,
context,
ratelimit=ratelimit,
)
defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None, def get_room_data(self, user_id=None, room_id=None,
event_type=None, state_key="", is_guest=False): event_type=None, state_key="", is_guest=False):
@ -470,9 +400,189 @@ class MessageHandler(BaseHandler):
for user_id, profile in users_with_profile.iteritems() for user_id, profile in users_with_profile.iteritems()
}) })
@measure_func("_create_new_client_event")
class EventCreationHandler(object):
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
self.profile_handler = hs.get_profile_handler()
self.event_builder_factory = hs.get_event_builder_factory()
self.server_name = hs.hostname
self.ratelimiter = hs.get_ratelimiter()
self.notifier = hs.get_notifier()
self.config = hs.config
self.http_client = hs.get_simple_http_client()
# This is only used to get at ratelimit function, and maybe_kick_guest_users
self.base_handler = BaseHandler(hs)
self.pusher_pool = hs.get_pusherpool()
# We arbitrarily limit concurrent event creation for a room to 5.
# This is to stop us from diverging history *too* much.
self.limiter = Limiter(max_count=5)
self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_new_client_event(self, builder, requester=None, prev_event_ids=None): def create_event(self, requester, event_dict, token_id=None, txn_id=None,
prev_event_ids=None):
"""
Given a dict from a client, create a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
etc.
Adds display names to Join membership events.
Args:
requester
event_dict (dict): An entire event
token_id (str)
txn_id (str)
prev_event_ids (list): The prev event ids to use when creating the event
Returns:
Tuple of created event (FrozenEvent), Context
"""
builder = self.event_builder_factory.new(event_dict)
with (yield self.limiter.queue(builder.room_id)):
self.validator.validate_new(builder)
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
if membership in {Membership.JOIN, Membership.INVITE}:
# If event doesn't include a display name, add one.
profile = self.profile_handler
content = builder.content
try:
if "displayname" not in content:
content["displayname"] = yield profile.get_displayname(target)
if "avatar_url" not in content:
content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s",
target, e
)
if token_id is not None:
builder.internal_metadata.token_id = token_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
event, context = yield self.create_new_client_event(
builder=builder,
requester=requester,
prev_event_ids=prev_event_ids,
)
defer.returnValue((event, context))
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
"""
Persists and notifies local clients and federation of an event.
Args:
event (FrozenEvent) the event to send.
context (Context) the context of the event.
ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest.
"""
if event.type == EventTypes.Member:
raise SynapseError(
500,
"Tried to send member event through non-member codepath"
)
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
prev_state = yield self.deduplicate_state_event(event, context)
if prev_state is not None:
defer.returnValue(prev_state)
yield self.handle_new_client_event(
requester=requester,
event=event,
context=context,
ratelimit=ratelimit,
)
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context):
"""
Checks whether event is in the latest resolved state in context.
If so, returns the version of the event in context.
Otherwise, returns None.
"""
prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
defer.returnValue(prev_event)
return
@defer.inlineCallbacks
def create_and_send_nonmember_event(
self,
requester,
event_dict,
ratelimit=True,
txn_id=None
):
"""
Creates an event, then sends it.
See self.create_event and self.send_nonmember_event.
"""
event, context = yield self.create_event(
requester,
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id
)
spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, basestring):
spam_error = "Spam is not permitted here"
raise SynapseError(
403, spam_error, Codes.FORBIDDEN
)
yield self.send_nonmember_event(
requester,
event,
context,
ratelimit=ratelimit,
)
defer.returnValue(event)
@measure_func("create_new_client_event")
@defer.inlineCallbacks
def create_new_client_event(self, builder, requester=None, prev_event_ids=None):
if prev_event_ids: if prev_event_ids:
prev_events = yield self.store.add_event_hashes(prev_event_ids) prev_events = yield self.store.add_event_hashes(prev_event_ids)
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids) prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
@ -509,9 +619,7 @@ class MessageHandler(BaseHandler):
builder.prev_events = prev_events builder.prev_events = prev_events
builder.depth = depth builder.depth = depth
state_handler = self.state_handler context = yield self.state.compute_event_context(builder)
context = yield state_handler.compute_event_context(builder)
if requester: if requester:
context.app_service = requester.app_service context.app_service = requester.app_service
@ -546,12 +654,21 @@ class MessageHandler(BaseHandler):
event, event,
context, context,
ratelimit=True, ratelimit=True,
extra_users=[] extra_users=[],
): ):
# We now need to go and hit out to wherever we need to hit out to. """Processes a new event. This includes checking auth, persisting it,
notifying users, sending to remote servers, etc.
if ratelimit: If called from a worker will hit out to the master process for final
yield self.ratelimit(requester) processing.
Args:
requester (Requester)
event (FrozenEvent)
context (EventContext)
ratelimit (bool)
extra_users (list(UserID)): Any extra users to notify about event
"""
try: try:
yield self.auth.check_from_context(event, context) yield self.auth.check_from_context(event, context)
@ -567,7 +684,58 @@ class MessageHandler(BaseHandler):
logger.exception("Failed to encode content: %r", event.content) logger.exception("Failed to encode content: %r", event.content)
raise raise
yield self.maybe_kick_guest_users(event, context) yield self.action_generator.handle_push_actions_for_event(
event, context
)
try:
# If we're a worker we need to hit out to the master.
if self.config.worker_app:
yield send_event_to_master(
self.http_client,
host=self.config.worker_replication_host,
port=self.config.worker_replication_http_port,
requester=requester,
event=event,
context=context,
ratelimit=ratelimit,
extra_users=extra_users,
)
return
yield self.persist_and_notify_client_event(
requester,
event,
context,
ratelimit=ratelimit,
extra_users=extra_users,
)
except: # noqa: E722, as we reraise the exception this is fine.
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
preserve_fn(self.store.remove_push_actions_from_staging)(event.event_id)
raise
@defer.inlineCallbacks
def persist_and_notify_client_event(
self,
requester,
event,
context,
ratelimit=True,
extra_users=[],
):
"""Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth.
This should only be run on master.
"""
assert not self.config.worker_app
if ratelimit:
yield self.base_handler.ratelimit(requester)
yield self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias: if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least) # Check the alias is acually valid (at this time at least)
@ -660,10 +828,6 @@ class MessageHandler(BaseHandler):
"Changing the room create event is forbidden", "Changing the room create event is forbidden",
) )
yield self.action_generator.handle_push_actions_for_event(
event, context
)
(event_stream_id, max_stream_id) = yield self.store.persist_event( (event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context event, context=context
) )
@ -683,3 +847,9 @@ class MessageHandler(BaseHandler):
) )
preserve_fn(_notify)() preserve_fn(_notify)()
if event.type == EventTypes.Message:
presence = self.hs.get_presence_handler()
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
preserve_fn(presence.bump_presence_active_time)(requester.user)

View file

@ -93,29 +93,30 @@ class PresenceHandler(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.wheel_timer = WheelTimer() self.wheel_timer = WheelTimer()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.replication = hs.get_replication_layer()
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.replication.register_edu_handler( federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
"m.presence", self.incoming_presence "m.presence", self.incoming_presence
) )
self.replication.register_edu_handler( federation_registry.register_edu_handler(
"m.presence_invite", "m.presence_invite",
lambda origin, content: self.invite_presence( lambda origin, content: self.invite_presence(
observed_user=UserID.from_string(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]), observer_user=UserID.from_string(content["observer_user"]),
) )
) )
self.replication.register_edu_handler( federation_registry.register_edu_handler(
"m.presence_accept", "m.presence_accept",
lambda origin, content: self.accept_presence( lambda origin, content: self.accept_presence(
observed_user=UserID.from_string(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]), observer_user=UserID.from_string(content["observer_user"]),
) )
) )
self.replication.register_edu_handler( federation_registry.register_edu_handler(
"m.presence_deny", "m.presence_deny",
lambda origin, content: self.deny_presence( lambda origin, content: self.deny_presence(
observed_user=UserID.from_string(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),

View file

@ -31,8 +31,8 @@ class ProfileHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(ProfileHandler, self).__init__(hs) super(ProfileHandler, self).__init__(hs)
self.federation = hs.get_replication_layer() self.federation = hs.get_federation_client()
self.federation.register_query_handler( hs.get_federation_registry().register_query_handler(
"profile", self.on_profile_query "profile", self.on_profile_query
) )
@ -233,7 +233,7 @@ class ProfileHandler(BaseHandler):
) )
for room_id in room_ids: for room_id in room_ids:
handler = self.hs.get_handlers().room_member_handler handler = self.hs.get_room_member_handler()
try: try:
# Assume the target_user isn't a guest, # Assume the target_user isn't a guest,
# because we don't let guests set profile or avatar data. # because we don't let guests set profile or avatar data.

View file

@ -41,9 +41,9 @@ class ReadMarkerHandler(BaseHandler):
""" """
with (yield self.read_marker_linearizer.queue((room_id, user_id))): with (yield self.read_marker_linearizer.queue((room_id, user_id))):
account_data = yield self.store.get_account_data_for_room(user_id, room_id) existing_read_marker = yield self.store.get_account_data_for_room_and_type(
user_id, room_id, "m.fully_read",
existing_read_marker = account_data.get("m.fully_read", None) )
should_update = True should_update = True

View file

@ -35,7 +35,7 @@ class ReceiptsHandler(BaseHandler):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler( hs.get_federation_registry().register_edu_handler(
"m.receipt", self._received_remote_receipt "m.receipt", self._received_remote_receipt
) )
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()

View file

@ -448,16 +448,34 @@ class RegistrationHandler(BaseHandler):
return self.hs.get_auth_handler() return self.hs.get_auth_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def guest_access_token_for(self, medium, address, inviter_user_id): def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
"""Get a guest access token for a 3PID, creating a guest account if
one doesn't already exist.
Args:
medium (str)
address (str)
inviter_user_id (str): The user ID who is trying to invite the
3PID
Returns:
Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
3PID guest account.
"""
access_token = yield self.store.get_3pid_guest_access_token(medium, address) access_token = yield self.store.get_3pid_guest_access_token(medium, address)
if access_token: if access_token:
defer.returnValue(access_token) user_info = yield self.auth.get_user_by_access_token(
access_token
)
_, access_token = yield self.register( defer.returnValue((user_info["user"].to_string(), access_token))
user_id, access_token = yield self.register(
generate_token=True, generate_token=True,
make_guest=True make_guest=True
) )
access_token = yield self.store.save_or_get_3pid_guest_access_token( access_token = yield self.store.save_or_get_3pid_guest_access_token(
medium, address, access_token, inviter_user_id medium, address, access_token, inviter_user_id
) )
defer.returnValue(access_token)
defer.returnValue((user_id, access_token))

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd # Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -64,6 +65,7 @@ class RoomCreationHandler(BaseHandler):
super(RoomCreationHandler, self).__init__(hs) super(RoomCreationHandler, self).__init__(hs)
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def create_room(self, requester, config, ratelimit=True): def create_room(self, requester, config, ratelimit=True):
@ -163,13 +165,11 @@ class RoomCreationHandler(BaseHandler):
creation_content = config.get("creation_content", {}) creation_content = config.get("creation_content", {})
msg_handler = self.hs.get_handlers().message_handler room_member_handler = self.hs.get_room_member_handler()
room_member_handler = self.hs.get_handlers().room_member_handler
yield self._send_events_for_new_room( yield self._send_events_for_new_room(
requester, requester,
room_id, room_id,
msg_handler,
room_member_handler, room_member_handler,
preset_config=preset_config, preset_config=preset_config,
invite_list=invite_list, invite_list=invite_list,
@ -181,7 +181,7 @@ class RoomCreationHandler(BaseHandler):
if "name" in config: if "name" in config:
name = config["name"] name = config["name"]
yield msg_handler.create_and_send_nonmember_event( yield self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Name, "type": EventTypes.Name,
@ -194,7 +194,7 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config: if "topic" in config:
topic = config["topic"] topic = config["topic"]
yield msg_handler.create_and_send_nonmember_event( yield self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Topic, "type": EventTypes.Topic,
@ -224,7 +224,7 @@ class RoomCreationHandler(BaseHandler):
id_server = invite_3pid["id_server"] id_server = invite_3pid["id_server"]
address = invite_3pid["address"] address = invite_3pid["address"]
medium = invite_3pid["medium"] medium = invite_3pid["medium"]
yield self.hs.get_handlers().room_member_handler.do_3pid_invite( yield self.hs.get_room_member_handler().do_3pid_invite(
room_id, room_id,
requester.user, requester.user,
medium, medium,
@ -249,7 +249,6 @@ class RoomCreationHandler(BaseHandler):
self, self,
creator, # A Requester object. creator, # A Requester object.
room_id, room_id,
msg_handler,
room_member_handler, room_member_handler,
preset_config, preset_config,
invite_list, invite_list,
@ -272,7 +271,7 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def send(etype, content, **kwargs): def send(etype, content, **kwargs):
event = create(etype, content, **kwargs) event = create(etype, content, **kwargs)
yield msg_handler.create_and_send_nonmember_event( yield self.event_creation_handler.create_and_send_nonmember_event(
creator, creator,
event, event,
ratelimit=False ratelimit=False
@ -476,12 +475,9 @@ class RoomEventSource(object):
user.to_string() user.to_string()
) )
if app_service: if app_service:
events, end_key = yield self.store.get_appservice_room_stream( # We no longer support AS users using /sync directly.
service=app_service, # See https://github.com/matrix-org/matrix-doc/issues/1144
from_key=from_key, raise NotImplementedError()
to_key=to_key,
limit=limit,
)
else: else:
room_events = yield self.store.get_membership_changes_for_user( room_events = yield self.store.get_membership_changes_for_user(
user.to_string(), from_key, to_key user.to_string(), from_key, to_key

View file

@ -203,7 +203,8 @@ class RoomListHandler(BaseHandler):
if limit: if limit:
step = limit + 1 step = limit + 1
else: else:
step = len(rooms_to_scan) # step cannot be zero
step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
chunk = [] chunk = []
for i in xrange(0, len(rooms_to_scan), step): for i in xrange(0, len(rooms_to_scan), step):
@ -408,7 +409,7 @@ class RoomListHandler(BaseHandler):
def _get_remote_list_cached(self, server_name, limit=None, since_token=None, def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
search_filter=None, include_all_networks=False, search_filter=None, include_all_networks=False,
third_party_instance_id=None,): third_party_instance_id=None,):
repl_layer = self.hs.get_replication_layer() repl_layer = self.hs.get_federation_client()
if search_filter: if search_filter:
# We can't cache when asking for search # We can't cache when asking for search
return repl_layer.get_public_rooms( return repl_layer.get_public_rooms(

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -29,23 +30,31 @@ from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.types import UserID, RoomID from synapse.types import UserID, RoomID
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.util.distributor import user_left_room, user_joined_room from synapse.util.distributor import user_left_room, user_joined_room
from ._base import BaseHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
id_server_scheme = "https://" id_server_scheme = "https://"
class RoomMemberHandler(BaseHandler): class RoomMemberHandler(object):
# TODO(paul): This handler currently contains a messy conflation of # TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level # low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns # API that takes ID strings and returns pagination chunks. These concerns
# ought to be separated out a lot better. # ought to be separated out a lot better.
def __init__(self, hs): def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs) self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.state_handler = hs.get_state_handler()
self.config = hs.config
self.simple_http_client = hs.get_simple_http_client()
self.federation_handler = hs.get_handlers().federation_handler
self.directory_handler = hs.get_handlers().directory_handler
self.registration_handler = hs.get_handlers().registration_handler
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.event_creation_hander = hs.get_event_creation_handler()
self.member_linearizer = Linearizer(name="member") self.member_linearizer = Linearizer(name="member")
@ -66,13 +75,12 @@ class RoomMemberHandler(BaseHandler):
): ):
if content is None: if content is None:
content = {} content = {}
msg_handler = self.hs.get_handlers().message_handler
content["membership"] = membership content["membership"] = membership
if requester.is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
event, context = yield msg_handler.create_event( event, context = yield self.event_creation_hander.create_event(
requester, requester,
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
@ -90,12 +98,14 @@ class RoomMemberHandler(BaseHandler):
) )
# Check if this event matches the previous membership event for the user. # Check if this event matches the previous membership event for the user.
duplicate = yield msg_handler.deduplicate_state_event(event, context) duplicate = yield self.event_creation_hander.deduplicate_state_event(
event, context,
)
if duplicate is not None: if duplicate is not None:
# Discard the new event since this membership change is a no-op. # Discard the new event since this membership change is a no-op.
defer.returnValue(duplicate) defer.returnValue(duplicate)
yield msg_handler.handle_new_client_event( yield self.event_creation_hander.handle_new_client_event(
requester, requester,
event, event,
context, context,
@ -127,7 +137,20 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def remote_join(self, remote_room_hosts, room_id, user, content): def _remote_join(self, remote_room_hosts, room_id, user, content):
"""Try and join a room that this server is not in
Args:
remote_room_hosts (list[str]): List of servers that can be used
to join via.
room_id (str): Room that we are trying to join
user (UserID): User who is trying to join
content (dict): A dict that should be used as the content of the
join event.
Returns:
Deferred
"""
if len(remote_room_hosts) == 0: if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers") raise SynapseError(404, "No known servers")
@ -135,7 +158,7 @@ class RoomMemberHandler(BaseHandler):
# join dance for now, since we're kinda implicitly checking # join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we # that we are allowed to join when we decide whether or not we
# need to do the invite/join dance. # need to do the invite/join dance.
yield self.hs.get_handlers().federation_handler.do_invite_join( yield self.federation_handler.do_invite_join(
remote_room_hosts, remote_room_hosts,
room_id, room_id,
user.to_string(), user.to_string(),
@ -143,6 +166,43 @@ class RoomMemberHandler(BaseHandler):
) )
yield user_joined_room(self.distributor, user, room_id) yield user_joined_room(self.distributor, user, room_id)
@defer.inlineCallbacks
def _remote_reject_invite(self, remote_room_hosts, room_id, target):
"""Attempt to reject an invite for a room this server is not in. If we
fail to do so we locally mark the invite as rejected.
Args:
remote_room_hosts (list[str]): List of servers to use to try and
reject invite
room_id (str)
target (UserID): The user rejecting the invite
Returns:
Deferred[dict]: A dictionary to be returned to the client, may
include event_id etc, or nothing if we locally rejected
"""
fed_handler = self.federation_handler
try:
ret = yield fed_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
target.to_string(),
)
defer.returnValue(ret)
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
#
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
logger.warn("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(
target.to_string(), room_id
)
defer.returnValue({})
@defer.inlineCallbacks @defer.inlineCallbacks
def update_membership( def update_membership(
self, self,
@ -201,8 +261,7 @@ class RoomMemberHandler(BaseHandler):
# if this is a join with a 3pid signature, we may need to turn a 3pid # if this is a join with a 3pid signature, we may need to turn a 3pid
# invite into a normal invite before we can handle the join. # invite into a normal invite before we can handle the join.
if third_party_signed is not None: if third_party_signed is not None:
replication = self.hs.get_replication_layer() yield self.federation_handler.exchange_third_party_invite(
yield replication.exchange_third_party_invite(
third_party_signed["sender"], third_party_signed["sender"],
target.to_string(), target.to_string(),
room_id, room_id,
@ -223,7 +282,7 @@ class RoomMemberHandler(BaseHandler):
requester.user, requester.user,
) )
if not is_requester_admin: if not is_requester_admin:
if self.hs.config.block_non_admin_invites: if self.config.block_non_admin_invites:
logger.info( logger.info(
"Blocking invite: user is not admin and non-admin " "Blocking invite: user is not admin and non-admin "
"invites disabled" "invites disabled"
@ -282,7 +341,7 @@ class RoomMemberHandler(BaseHandler):
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
if not is_host_in_room: if not is_host_in_room:
inviter = yield self.get_inviter(target.to_string(), room_id) inviter = yield self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter): if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain) remote_room_hosts.append(inviter.domain)
@ -296,7 +355,7 @@ class RoomMemberHandler(BaseHandler):
if requester.is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
ret = yield self.remote_join( ret = yield self._remote_join(
remote_room_hosts, room_id, target, content remote_room_hosts, room_id, target, content
) )
defer.returnValue(ret) defer.returnValue(ret)
@ -304,7 +363,7 @@ class RoomMemberHandler(BaseHandler):
elif effective_membership_state == Membership.LEAVE: elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room: if not is_host_in_room:
# perhaps we've been invited # perhaps we've been invited
inviter = yield self.get_inviter(target.to_string(), room_id) inviter = yield self._get_inviter(target.to_string(), room_id)
if not inviter: if not inviter:
raise SynapseError(404, "Not a known room") raise SynapseError(404, "Not a known room")
@ -318,28 +377,10 @@ class RoomMemberHandler(BaseHandler):
else: else:
# send the rejection to the inviter's HS. # send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain] remote_room_hosts = remote_room_hosts + [inviter.domain]
fed_handler = self.hs.get_handlers().federation_handler res = yield self._remote_reject_invite(
try: remote_room_hosts, room_id, target,
ret = yield fed_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
target.to_string(),
) )
defer.returnValue(ret) defer.returnValue(res)
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
#
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
logger.warn("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(
target.to_string(), room_id
)
defer.returnValue({})
res = yield self._local_membership_update( res = yield self._local_membership_update(
requester=requester, requester=requester,
@ -394,8 +435,9 @@ class RoomMemberHandler(BaseHandler):
else: else:
requester = synapse.types.create_requester(target_user) requester = synapse.types.create_requester(target_user)
message_handler = self.hs.get_handlers().message_handler prev_event = yield self.event_creation_hander.deduplicate_state_event(
prev_event = yield message_handler.deduplicate_state_event(event, context) event, context,
)
if prev_event is not None: if prev_event is not None:
return return
@ -412,7 +454,7 @@ class RoomMemberHandler(BaseHandler):
if is_blocked: if is_blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
yield message_handler.handle_new_client_event( yield self.event_creation_hander.handle_new_client_event(
requester, requester,
event, event,
context, context,
@ -473,7 +515,7 @@ class RoomMemberHandler(BaseHandler):
Raises: Raises:
SynapseError if room alias could not be found. SynapseError if room alias could not be found.
""" """
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.directory_handler
mapping = yield directory_handler.get_association(room_alias) mapping = yield directory_handler.get_association(room_alias)
if not mapping: if not mapping:
@ -485,7 +527,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue((RoomID.from_string(room_id), servers)) defer.returnValue((RoomID.from_string(room_id), servers))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_inviter(self, user_id, room_id): def _get_inviter(self, user_id, room_id):
invite = yield self.store.get_invite_for_user_in_room( invite = yield self.store.get_invite_for_user_in_room(
user_id=user_id, user_id=user_id,
room_id=room_id, room_id=room_id,
@ -504,7 +546,7 @@ class RoomMemberHandler(BaseHandler):
requester, requester,
txn_id txn_id
): ):
if self.hs.config.block_non_admin_invites: if self.config.block_non_admin_invites:
is_requester_admin = yield self.auth.is_server_admin( is_requester_admin = yield self.auth.is_server_admin(
requester.user, requester.user,
) )
@ -551,7 +593,7 @@ class RoomMemberHandler(BaseHandler):
str: the matrix ID of the 3pid, or None if it is not recognized. str: the matrix ID of the 3pid, or None if it is not recognized.
""" """
try: try:
data = yield self.hs.get_simple_http_client().get_json( data = yield self.simple_http_client.get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,), "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{ {
"medium": medium, "medium": medium,
@ -562,7 +604,7 @@ class RoomMemberHandler(BaseHandler):
if "mxid" in data: if "mxid" in data:
if "signatures" not in data: if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding") raise AuthError(401, "No signatures on 3pid binding")
self.verify_any_signature(data, id_server) yield self._verify_any_signature(data, id_server)
defer.returnValue(data["mxid"]) defer.returnValue(data["mxid"])
except IOError as e: except IOError as e:
@ -570,11 +612,11 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def verify_any_signature(self, data, server_hostname): def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]: if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,)) raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items(): for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.hs.get_simple_http_client().get_json( key_data = yield self.simple_http_client.get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" % "%s%s/_matrix/identity/api/v1/pubkey/%s" %
(id_server_scheme, server_hostname, key_name,), (id_server_scheme, server_hostname, key_name,),
) )
@ -599,7 +641,7 @@ class RoomMemberHandler(BaseHandler):
user, user,
txn_id txn_id
): ):
room_state = yield self.hs.get_state_handler().get_current_state(room_id) room_state = yield self.state_handler.get_current_state(room_id)
inviter_display_name = "" inviter_display_name = ""
inviter_avatar_url = "" inviter_avatar_url = ""
@ -644,8 +686,7 @@ class RoomMemberHandler(BaseHandler):
) )
) )
msg_handler = self.hs.get_handlers().message_handler yield self.event_creation_hander.create_and_send_nonmember_event(
yield msg_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.ThirdPartyInvite, "type": EventTypes.ThirdPartyInvite,
@ -724,24 +765,20 @@ class RoomMemberHandler(BaseHandler):
"sender_avatar_url": inviter_avatar_url, "sender_avatar_url": inviter_avatar_url,
} }
if self.hs.config.invite_3pid_guest: if self.config.invite_3pid_guest:
registration_handler = self.hs.get_handlers().registration_handler rh = self.registration_handler
guest_access_token = yield registration_handler.guest_access_token_for( guest_user_id, guest_access_token = yield rh.get_or_register_3pid_guest(
medium=medium, medium=medium,
address=address, address=address,
inviter_user_id=inviter_user_id, inviter_user_id=inviter_user_id,
) )
guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
guest_access_token
)
invite_config.update({ invite_config.update({
"guest_access_token": guest_access_token, "guest_access_token": guest_access_token,
"guest_user_id": guest_user_info["user"].to_string(), "guest_user_id": guest_user_id,
}) })
data = yield self.hs.get_simple_http_client().post_urlencoded_get_json( data = yield self.simple_http_client.post_urlencoded_get_json(
is_url, is_url,
invite_config invite_config
) )

View file

@ -235,10 +235,10 @@ class SyncHandler(object):
defer.returnValue(rules) defer.returnValue(rules)
@defer.inlineCallbacks @defer.inlineCallbacks
def ephemeral_by_room(self, sync_config, now_token, since_token=None): def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
"""Get the ephemeral events for each room the user is in """Get the ephemeral events for each room the user is in
Args: Args:
sync_config (SyncConfig): The flags, filters and user for the sync. sync_result_builder(SyncResultBuilder)
now_token (StreamToken): Where the server is currently up to. now_token (StreamToken): Where the server is currently up to.
since_token (StreamToken): Where the server was when the client since_token (StreamToken): Where the server was when the client
last synced. last synced.
@ -248,10 +248,12 @@ class SyncHandler(object):
typing events for that room. typing events for that room.
""" """
sync_config = sync_result_builder.sync_config
with Measure(self.clock, "ephemeral_by_room"): with Measure(self.clock, "ephemeral_by_room"):
typing_key = since_token.typing_key if since_token else "0" typing_key = since_token.typing_key if since_token else "0"
room_ids = yield self.store.get_rooms_for_user(sync_config.user.to_string()) room_ids = sync_result_builder.joined_room_ids
typing_source = self.event_sources.sources["typing"] typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events( typing, typing_key = yield typing_source.get_new_events(
@ -565,10 +567,22 @@ class SyncHandler(object):
# Always use the `now_token` in `SyncResultBuilder` # Always use the `now_token` in `SyncResultBuilder`
now_token = yield self.event_sources.get_current_token() now_token = yield self.event_sources.get_current_token()
user_id = sync_config.user.to_string()
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service:
# We no longer support AS users using /sync directly.
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
else:
joined_room_ids = yield self.get_rooms_for_user_at(
user_id, now_token.room_stream_id,
)
sync_result_builder = SyncResultBuilder( sync_result_builder = SyncResultBuilder(
sync_config, full_state, sync_config, full_state,
since_token=since_token, since_token=since_token,
now_token=now_token, now_token=now_token,
joined_room_ids=joined_room_ids,
) )
account_data_by_room = yield self._generate_sync_entry_for_account_data( account_data_by_room = yield self._generate_sync_entry_for_account_data(
@ -603,7 +617,6 @@ class SyncHandler(object):
device_id = sync_config.device_id device_id = sync_config.device_id
one_time_key_counts = {} one_time_key_counts = {}
if device_id: if device_id:
user_id = sync_config.user.to_string()
one_time_key_counts = yield self.store.count_e2e_one_time_keys( one_time_key_counts = yield self.store.count_e2e_one_time_keys(
user_id, device_id user_id, device_id
) )
@ -891,7 +904,7 @@ class SyncHandler(object):
ephemeral_by_room = {} ephemeral_by_room = {}
else: else:
now_token, ephemeral_by_room = yield self.ephemeral_by_room( now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_result_builder.sync_config, sync_result_builder,
now_token=sync_result_builder.now_token, now_token=sync_result_builder.now_token,
since_token=sync_result_builder.since_token, since_token=sync_result_builder.since_token,
) )
@ -996,15 +1009,8 @@ class SyncHandler(object):
if rooms_changed: if rooms_changed:
defer.returnValue(True) defer.returnValue(True)
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
joined_room_ids = set(r.room_id for r in rooms)
else:
joined_room_ids = yield self.store.get_rooms_for_user(user_id)
stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
for room_id in joined_room_ids: for room_id in sync_result_builder.joined_room_ids:
if self.store.has_room_changed_since(room_id, stream_id): if self.store.has_room_changed_since(room_id, stream_id):
defer.returnValue(True) defer.returnValue(True)
defer.returnValue(False) defer.returnValue(False)
@ -1028,13 +1034,6 @@ class SyncHandler(object):
assert since_token assert since_token
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
joined_room_ids = set(r.room_id for r in rooms)
else:
joined_room_ids = yield self.store.get_rooms_for_user(user_id)
# Get a list of membership change events that have happened. # Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_membership_changes_for_user( rooms_changed = yield self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key user_id, since_token.room_key, now_token.room_key
@ -1057,7 +1056,7 @@ class SyncHandler(object):
# we do send down the room, and with full state, where necessary # we do send down the room, and with full state, where necessary
old_state_ids = None old_state_ids = None
if room_id in joined_room_ids and non_joins: if room_id in sync_result_builder.joined_room_ids and non_joins:
# Always include if the user (re)joined the room, especially # Always include if the user (re)joined the room, especially
# important so that device list changes are calculated correctly. # important so that device list changes are calculated correctly.
# If there are non join member events, but we are still in the room, # If there are non join member events, but we are still in the room,
@ -1067,7 +1066,7 @@ class SyncHandler(object):
# User is in the room so we don't need to do the invite/leave checks # User is in the room so we don't need to do the invite/leave checks
continue continue
if room_id in joined_room_ids or has_join: if room_id in sync_result_builder.joined_room_ids or has_join:
old_state_ids = yield self.get_state_at(room_id, since_token) old_state_ids = yield self.get_state_at(room_id, since_token)
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
old_mem_ev = None old_mem_ev = None
@ -1079,7 +1078,7 @@ class SyncHandler(object):
newly_joined_rooms.append(room_id) newly_joined_rooms.append(room_id)
# If user is in the room then we don't need to do the invite/leave checks # If user is in the room then we don't need to do the invite/leave checks
if room_id in joined_room_ids: if room_id in sync_result_builder.joined_room_ids:
continue continue
if not non_joins: if not non_joins:
@ -1146,7 +1145,7 @@ class SyncHandler(object):
# Get all events for rooms we're currently joined to. # Get all events for rooms we're currently joined to.
room_to_events = yield self.store.get_room_events_stream_for_rooms( room_to_events = yield self.store.get_room_events_stream_for_rooms(
room_ids=joined_room_ids, room_ids=sync_result_builder.joined_room_ids,
from_key=since_token.room_key, from_key=since_token.room_key,
to_key=now_token.room_key, to_key=now_token.room_key,
limit=timeline_limit + 1, limit=timeline_limit + 1,
@ -1154,7 +1153,7 @@ class SyncHandler(object):
# We loop through all room ids, even if there are no new events, in case # We loop through all room ids, even if there are no new events, in case
# there are non room events taht we need to notify about. # there are non room events taht we need to notify about.
for room_id in joined_room_ids: for room_id in sync_result_builder.joined_room_ids:
room_entry = room_to_events.get(room_id, None) room_entry = room_to_events.get(room_id, None)
if room_entry: if room_entry:
@ -1362,6 +1361,54 @@ class SyncHandler(object):
else: else:
raise Exception("Unrecognized rtype: %r", room_builder.rtype) raise Exception("Unrecognized rtype: %r", room_builder.rtype)
@defer.inlineCallbacks
def get_rooms_for_user_at(self, user_id, stream_ordering):
"""Get set of joined rooms for a user at the given stream ordering.
The stream ordering *must* be recent, otherwise this may throw an
exception if older than a month. (This function is called with the
current token, which should be perfectly fine).
Args:
user_id (str)
stream_ordering (int)
ReturnValue:
Deferred[frozenset[str]]: Set of room_ids the user is in at given
stream_ordering.
"""
joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(
user_id,
)
joined_room_ids = set()
# We need to check that the stream ordering of the join for each room
# is before the stream_ordering asked for. This might not be the case
# if the user joins a room between us getting the current token and
# calling `get_rooms_for_user_with_stream_ordering`.
# If the membership's stream ordering is after the given stream
# ordering, we need to go and work out if the user was in the room
# before.
for room_id, membership_stream_ordering in joined_rooms:
if membership_stream_ordering <= stream_ordering:
joined_room_ids.add(room_id)
continue
logger.info("User joined room after current token: %s", room_id)
extrems = yield self.store.get_forward_extremeties_for_room(
room_id, stream_ordering,
)
users_in_room = yield self.state.get_current_user_in_room(
room_id, extrems,
)
if user_id in users_in_room:
joined_room_ids.add(room_id)
joined_room_ids = frozenset(joined_room_ids)
defer.returnValue(joined_room_ids)
def _action_has_highlight(actions): def _action_has_highlight(actions):
for action in actions: for action in actions:
@ -1411,7 +1458,8 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
class SyncResultBuilder(object): class SyncResultBuilder(object):
"Used to help build up a new SyncResult for a user" "Used to help build up a new SyncResult for a user"
def __init__(self, sync_config, full_state, since_token, now_token): def __init__(self, sync_config, full_state, since_token, now_token,
joined_room_ids):
""" """
Args: Args:
sync_config(SyncConfig) sync_config(SyncConfig)
@ -1423,6 +1471,7 @@ class SyncResultBuilder(object):
self.full_state = full_state self.full_state = full_state
self.since_token = since_token self.since_token = since_token
self.now_token = now_token self.now_token = now_token
self.joined_room_ids = joined_room_ids
self.presence = [] self.presence = []
self.account_data = [] self.account_data = []

View file

@ -56,7 +56,7 @@ class TypingHandler(object):
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu) hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
hs.get_distributor().observe("user_left_room", self.user_left_room) hs.get_distributor().observe("user_left_room", self.user_left_room)

View file

@ -18,6 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import ( from synapse.api.errors import (
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes, CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
) )
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.logcontext import make_deferred_yieldable from synapse.util.logcontext import make_deferred_yieldable
from synapse.util import logcontext from synapse.util import logcontext
import synapse.metrics import synapse.metrics
@ -30,6 +31,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web.client import ( from twisted.web.client import (
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent, BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
readBody, PartialDownloadError, readBody, PartialDownloadError,
HTTPConnectionPool,
) )
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
from twisted.web.http import PotentialDataLoss from twisted.web.http import PotentialDataLoss
@ -64,13 +66,23 @@ class SimpleHttpClient(object):
""" """
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
pool = HTTPConnectionPool(reactor)
# the pusher makes lots of concurrent SSL connections to sygnal, and
# tends to do so in batches, so we need to allow the pool to keep lots
# of idle connections around.
pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
pool.cachedConnectionTimeout = 2 * 60
# The default context factory in Twisted 14.0.0 (which we require) is # The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation # BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser' # 'like a browser'
self.agent = Agent( self.agent = Agent(
reactor, reactor,
connectTimeout=15, connectTimeout=15,
contextFactory=hs.get_http_client_context_factory() contextFactory=hs.get_http_client_context_factory(),
pool=pool,
) )
self.user_agent = hs.version_string self.user_agent = hs.version_string
self.clock = hs.get_clock() self.clock = hs.get_clock()

View file

@ -357,8 +357,7 @@ def _get_hosts_for_srv_record(dns_client, host):
def eb(res, record_type): def eb(res, record_type):
if res.check(DNSNameError): if res.check(DNSNameError):
return [] return []
logger.warn("Error looking up %s for %s: %s", logger.warn("Error looking up %s for %s: %s", record_type, host, res)
record_type, host, res, res.value)
return res return res
# no logcontexts here, so we can safely fire these off and gatherResults # no logcontexts here, so we can safely fire these off and gatherResults

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -59,6 +60,11 @@ response_count = metrics.register_counter(
) )
) )
requests_counter = metrics.register_counter(
"requests_received",
labels=["method", "servlet", ],
)
outgoing_responses_counter = metrics.register_counter( outgoing_responses_counter = metrics.register_counter(
"responses", "responses",
labels=["method", "code"], labels=["method", "code"],
@ -145,7 +151,8 @@ def wrap_request_handler(request_handler, include_metrics=False):
# at the servlet name. For most requests that name will be # at the servlet name. For most requests that name will be
# JsonResource (or a subclass), and JsonResource._async_render # JsonResource (or a subclass), and JsonResource._async_render
# will update it once it picks a servlet. # will update it once it picks a servlet.
request_metrics.start(self.clock, name=self.__class__.__name__) servlet_name = self.__class__.__name__
request_metrics.start(self.clock, name=servlet_name)
request_context.request = request_id request_context.request = request_id
with request.processing(): with request.processing():
@ -154,6 +161,7 @@ def wrap_request_handler(request_handler, include_metrics=False):
if include_metrics: if include_metrics:
yield request_handler(self, request, request_metrics) yield request_handler(self, request, request_metrics)
else: else:
requests_counter.inc(request.method, servlet_name)
yield request_handler(self, request) yield request_handler(self, request)
except CodeMessageException as e: except CodeMessageException as e:
code = e.code code = e.code
@ -229,7 +237,7 @@ class JsonResource(HttpServer, resource.Resource):
""" This implements the HttpServer interface and provides JSON support for """ This implements the HttpServer interface and provides JSON support for
Resources. Resources.
Register callbacks via register_path() Register callbacks via register_paths()
Callbacks can return a tuple of status code and a dict in which case the Callbacks can return a tuple of status code and a dict in which case the
the dict will automatically be sent to the client as a JSON object. the dict will automatically be sent to the client as a JSON object.
@ -276,21 +284,7 @@ class JsonResource(HttpServer, resource.Resource):
This checks if anyone has registered a callback for that method and This checks if anyone has registered a callback for that method and
path. path.
""" """
if request.method == "OPTIONS": callback, group_dict = self._get_handler_for_request(request)
self._send_response(request, 200, {})
return
# Loop through all the registered callbacks to check if the method
# and path regex match
for path_entry in self.path_regexs.get(request.method, []):
m = path_entry.pattern.match(request.path)
if not m:
continue
# We found a match! First update the metrics object to indicate
# which servlet is handling the request.
callback = path_entry.callback
servlet_instance = getattr(callback, "__self__", None) servlet_instance = getattr(callback, "__self__", None)
if servlet_instance is not None: if servlet_instance is not None:
@ -299,6 +293,7 @@ class JsonResource(HttpServer, resource.Resource):
servlet_classname = "%r" % callback servlet_classname = "%r" % callback
request_metrics.name = servlet_classname request_metrics.name = servlet_classname
requests_counter.inc(request.method, servlet_classname)
# Now trigger the callback. If it returns a response, we send it # Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper # here. If it throws an exception, that is handled by the wrapper
@ -306,7 +301,7 @@ class JsonResource(HttpServer, resource.Resource):
kwargs = intern_dict({ kwargs = intern_dict({
name: urllib.unquote(value).decode("UTF-8") if value else value name: urllib.unquote(value).decode("UTF-8") if value else value
for name, value in m.groupdict().items() for name, value in group_dict.items()
}) })
callback_return = yield callback(request, **kwargs) callback_return = yield callback(request, **kwargs)
@ -314,11 +309,34 @@ class JsonResource(HttpServer, resource.Resource):
code, response = callback_return code, response = callback_return
self._send_response(request, code, response) self._send_response(request, code, response)
return def _get_handler_for_request(self, request):
"""Finds a callback method to handle the given request
Args:
request (twisted.web.http.Request):
Returns:
Tuple[Callable, dict[str, str]]: callback method, and the dict
mapping keys to path components as specified in the handler's
path match regexp.
The callback will normally be a method registered via
register_paths, so will return (possibly via Deferred) either
None, or a tuple of (http code, response body).
"""
if request.method == "OPTIONS":
return _options_handler, {}
# Loop through all the registered callbacks to check if the method
# and path regex match
for path_entry in self.path_regexs.get(request.method, []):
m = path_entry.pattern.match(request.path)
if m:
# We found a match!
return path_entry.callback, m.groupdict()
# Huh. No one wanted to handle that? Fiiiiiine. Send 400. # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
request_metrics.name = self.__class__.__name__ + ".UnrecognizedRequest" return _unrecognised_request_handler, {}
raise UnrecognizedRequestError()
def _send_response(self, request, code, response_json_object, def _send_response(self, request, code, response_json_object,
response_code_message=None): response_code_message=None):
@ -335,6 +353,34 @@ class JsonResource(HttpServer, resource.Resource):
) )
def _options_handler(request):
"""Request handler for OPTIONS requests
This is a request handler suitable for return from
_get_handler_for_request. It returns a 200 and an empty body.
Args:
request (twisted.web.http.Request):
Returns:
Tuple[int, dict]: http code, response body.
"""
return 200, {}
def _unrecognised_request_handler(request):
"""Request handler for unrecognised requests
This is a request handler suitable for return from
_get_handler_for_request. It actually just raises an
UnrecognizedRequestError.
Args:
request (twisted.web.http.Request):
"""
raise UnrecognizedRequestError()
class RequestMetrics(object): class RequestMetrics(object):
def start(self, clock, name): def start(self, clock, name):
self.start = clock.time_msec() self.start = clock.time_msec()

View file

@ -148,11 +148,13 @@ def parse_string_from_args(args, name, default=None, required=False,
return default return default
def parse_json_value_from_request(request): def parse_json_value_from_request(request, allow_empty_body=False):
"""Parse a JSON value from the body of a twisted HTTP request. """Parse a JSON value from the body of a twisted HTTP request.
Args: Args:
request: the twisted HTTP request. request: the twisted HTTP request.
allow_empty_body (bool): if True, an empty body will be accepted and
turned into None
Returns: Returns:
The JSON value. The JSON value.
@ -165,6 +167,9 @@ def parse_json_value_from_request(request):
except Exception: except Exception:
raise SynapseError(400, "Error reading JSON content.") raise SynapseError(400, "Error reading JSON content.")
if not content_bytes and allow_empty_body:
return None
try: try:
content = simplejson.loads(content_bytes) content = simplejson.loads(content_bytes)
except Exception as e: except Exception as e:
@ -174,17 +179,24 @@ def parse_json_value_from_request(request):
return content return content
def parse_json_object_from_request(request): def parse_json_object_from_request(request, allow_empty_body=False):
"""Parse a JSON object from the body of a twisted HTTP request. """Parse a JSON object from the body of a twisted HTTP request.
Args: Args:
request: the twisted HTTP request. request: the twisted HTTP request.
allow_empty_body (bool): if True, an empty body will be accepted and
turned into an empty dict.
Raises: Raises:
SynapseError if the request body couldn't be decoded as JSON or SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object. if it wasn't a JSON object.
""" """
content = parse_json_value_from_request(request) content = parse_json_value_from_request(
request, allow_empty_body=allow_empty_body,
)
if allow_empty_body and content is None:
return {}
if type(content) != dict: if type(content) != dict:
message = "Content must be a JSON object." message = "Content must be a JSON object."

View file

@ -57,15 +57,31 @@ class Metrics(object):
return metric return metric
def register_counter(self, *args, **kwargs): def register_counter(self, *args, **kwargs):
"""
Returns:
CounterMetric
"""
return self._register(CounterMetric, *args, **kwargs) return self._register(CounterMetric, *args, **kwargs)
def register_callback(self, *args, **kwargs): def register_callback(self, *args, **kwargs):
"""
Returns:
CallbackMetric
"""
return self._register(CallbackMetric, *args, **kwargs) return self._register(CallbackMetric, *args, **kwargs)
def register_distribution(self, *args, **kwargs): def register_distribution(self, *args, **kwargs):
"""
Returns:
DistributionMetric
"""
return self._register(DistributionMetric, *args, **kwargs) return self._register(DistributionMetric, *args, **kwargs)
def register_cache(self, *args, **kwargs): def register_cache(self, *args, **kwargs):
"""
Returns:
CacheMetric
"""
return self._register(CacheMetric, *args, **kwargs) return self._register(CacheMetric, *args, **kwargs)
@ -146,10 +162,15 @@ def runUntilCurrentTimer(func):
num_pending += 1 num_pending += 1
num_pending += len(reactor.threadCallQueue) num_pending += len(reactor.threadCallQueue)
start = time.time() * 1000 start = time.time() * 1000
ret = func(*args, **kwargs) ret = func(*args, **kwargs)
end = time.time() * 1000 end = time.time() * 1000
# record the amount of wallclock time spent running pending calls.
# This is a proxy for the actual amount of time between reactor polls,
# since about 25% of time is actually spent running things triggered by
# I/O events, but that is harder to capture without rewriting half the
# reactor.
tick_time.inc_by(end - start) tick_time.inc_by(end - start)
pending_calls_metric.inc_by(num_pending) pending_calls_metric.inc_by(num_pending)

View file

@ -193,7 +193,9 @@ class DistributionMetric(object):
class CacheMetric(object): class CacheMetric(object):
__slots__ = ("name", "cache_name", "hits", "misses", "size_callback") __slots__ = (
"name", "cache_name", "hits", "misses", "evicted_size", "size_callback",
)
def __init__(self, name, size_callback, cache_name): def __init__(self, name, size_callback, cache_name):
self.name = name self.name = name
@ -201,6 +203,7 @@ class CacheMetric(object):
self.hits = 0 self.hits = 0
self.misses = 0 self.misses = 0
self.evicted_size = 0
self.size_callback = size_callback self.size_callback = size_callback
@ -210,6 +213,9 @@ class CacheMetric(object):
def inc_misses(self): def inc_misses(self):
self.misses += 1 self.misses += 1
def inc_evictions(self, size=1):
self.evicted_size += size
def render(self): def render(self):
size = self.size_callback() size = self.size_callback()
hits = self.hits hits = self.hits
@ -219,6 +225,9 @@ class CacheMetric(object):
"""%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits), """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total), """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size), """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
"""%s:evicted_size{name="%s"} %d""" % (
self.name, self.cache_name, self.evicted_size
),
] ]

View file

@ -40,10 +40,6 @@ class ActionGenerator(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context): def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "action_for_event_by_user"): with Measure(self.clock, "action_for_event_by_user"):
actions_by_user = yield self.bulk_evaluator.action_for_event_by_user( yield self.bulk_evaluator.action_for_event_by_user(
event, context event, context
) )
context.push_actions = [
(uid, actions) for uid, actions in actions_by_user.iteritems()
]

View file

@ -137,11 +137,11 @@ class BulkPushRuleEvaluator(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def action_for_event_by_user(self, event, context): def action_for_event_by_user(self, event, context):
"""Given an event and context, evaluate the push rules and return """Given an event and context, evaluate the push rules and insert the
the results results into the event_push_actions_staging table.
Returns: Returns:
dict of user_id -> action Deferred
""" """
rules_by_user = yield self._get_rules_for_event(event, context) rules_by_user = yield self._get_rules_for_event(event, context)
actions_by_user = {} actions_by_user = {}
@ -190,9 +190,16 @@ class BulkPushRuleEvaluator(object):
if matches: if matches:
actions = [x for x in rule['actions'] if x != 'dont_notify'] actions = [x for x in rule['actions'] if x != 'dont_notify']
if actions and 'notify' in actions: if actions and 'notify' in actions:
# Push rules say we should notify the user of this event
actions_by_user[uid] = actions actions_by_user[uid] = actions
break break
defer.returnValue(actions_by_user)
# Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist
# the event)
yield self.store.add_push_actions_to_staging(
event.event_id, actions_by_user,
)
def _condition_checker(evaluator, conditions, uid, display_name, cache): def _condition_checker(evaluator, conditions, uid, display_name, cache):

View file

@ -13,21 +13,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from synapse.push import PusherConfigException
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.error import AlreadyCalled, AlreadyCancelled
import logging
import push_rule_evaluator import push_rule_evaluator
import push_tools import push_tools
import synapse
from synapse.push import PusherConfigException
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
http_push_processed_counter = metrics.register_counter(
"http_pushes_processed",
)
http_push_failed_counter = metrics.register_counter(
"http_pushes_failed",
)
class HttpPusher(object): class HttpPusher(object):
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
@ -152,9 +161,16 @@ class HttpPusher(object):
self.user_id, self.last_stream_ordering, self.max_stream_ordering self.user_id, self.last_stream_ordering, self.max_stream_ordering
) )
logger.info(
"Processing %i unprocessed push actions for %s starting at "
"stream_ordering %s",
len(unprocessed), self.name, self.last_stream_ordering,
)
for push_action in unprocessed: for push_action in unprocessed:
processed = yield self._process_one(push_action) processed = yield self._process_one(push_action)
if processed: if processed:
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action['stream_ordering'] self.last_stream_ordering = push_action['stream_ordering']
yield self.store.update_pusher_last_stream_ordering_and_success( yield self.store.update_pusher_last_stream_ordering_and_success(
@ -169,6 +185,7 @@ class HttpPusher(object):
self.failing_since self.failing_since
) )
else: else:
http_push_failed_counter.inc()
if not self.failing_since: if not self.failing_since:
self.failing_since = self.clock.time_msec() self.failing_since = self.clock.time_msec()
yield self.store.update_pusher_failing_since( yield self.store.update_pusher_failing_since(
@ -316,7 +333,10 @@ class HttpPusher(object):
try: try:
resp = yield self.http_client.post_json_get_json(self.url, notification_dict) resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
except Exception: except Exception:
logger.warn("Failed to push %s ", self.url) logger.warn(
"Failed to push event %s to %s",
event.event_id, self.name, exc_info=True,
)
defer.returnValue(False) defer.returnValue(False)
rejected = [] rejected = []
if 'rejected' in resp: if 'rejected' in resp:
@ -325,7 +345,7 @@ class HttpPusher(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _send_badge(self, badge): def _send_badge(self, badge):
logger.info("Sending updated badge count %d to %r", badge, self.user_id) logger.info("Sending updated badge count %d to %s", badge, self.name)
d = { d = {
'notification': { 'notification': {
'id': '', 'id': '',
@ -347,7 +367,10 @@ class HttpPusher(object):
try: try:
resp = yield self.http_client.post_json_get_json(self.url, d) resp = yield self.http_client.post_json_get_json(self.url, d)
except Exception: except Exception:
logger.exception("Failed to push %s ", self.url) logger.warn(
"Failed to send badge count to %s",
self.name, exc_info=True,
)
defer.returnValue(False) defer.returnValue(False)
rejected = [] rejected = []
if 'rejected' in resp: if 'rejected' in resp:

View file

@ -24,19 +24,19 @@ REQUIREMENTS = {
"unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
"signedjson>=1.0.0": ["signedjson>=1.0.0"], "signedjson>=1.0.0": ["signedjson>=1.0.0"],
"pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"], "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=16.0.0": ["twisted>=16.0.0"], "Twisted>=16.0.0": ["twisted>=16.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"], "pyyaml": ["yaml"],
"pyasn1": ["pyasn1"], "pyasn1": ["pyasn1"],
"daemonize": ["daemonize"], "daemonize": ["daemonize"],
"bcrypt": ["bcrypt"], "bcrypt": ["bcrypt>=3.1.0"],
"pillow": ["PIL"], "pillow": ["PIL"],
"pydenticon": ["pydenticon"], "pydenticon": ["pydenticon"],
"ujson": ["ujson"], "ujson": ["ujson"],
"blist": ["blist"], "blist": ["blist"],
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], "pysaml2>=3.0.0": ["saml2>=3.0.0"],
"pymacaroons-pynacl": ["pymacaroons"], "pymacaroons-pynacl": ["pymacaroons"],
"msgpack-python>=0.3.0": ["msgpack"], "msgpack-python>=0.3.0": ["msgpack"],
"phonenumbers>=8.2.0": ["phonenumbers"], "phonenumbers>=8.2.0": ["phonenumbers"],

View file

@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import send_event
from synapse.http.server import JsonResource
REPLICATION_PREFIX = "/_synapse/replication"
class ReplicationRestResource(JsonResource):
def __init__(self, hs):
JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(hs)
def register_servlets(self, hs):
send_event.register_servlets(hs, self)

View file

@ -0,0 +1,166 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import (
SynapseError, MatrixCodeMessageException, CodeMessageException,
)
from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.async import sleep
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.metrics import Measure
from synapse.types import Requester, UserID
import logging
import re
logger = logging.getLogger(__name__)
@defer.inlineCallbacks
def send_event_to_master(client, host, port, requester, event, context,
ratelimit, extra_users):
"""Send event to be handled on the master
Args:
client (SimpleHttpClient)
host (str): host of master
port (int): port on master listening for HTTP replication
requester (Requester)
event (FrozenEvent)
context (EventContext)
ratelimit (bool)
extra_users (list(UserID)): Any extra users to notify about event
"""
uri = "http://%s:%s/_synapse/replication/send_event/%s" % (
host, port, event.event_id,
)
payload = {
"event": event.get_pdu_json(),
"internal_metadata": event.internal_metadata.get_dict(),
"rejected_reason": event.rejected_reason,
"context": context.serialize(event),
"requester": requester.serialize(),
"ratelimit": ratelimit,
"extra_users": [u.to_string() for u in extra_users],
}
try:
# We keep retrying the same request for timeouts. This is so that we
# have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not.
while True:
try:
result = yield client.put_json(uri, payload)
break
except CodeMessageException as e:
if e.code != 504:
raise
logger.warn("send_event request timed out")
# If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway.
yield sleep(1)
except MatrixCodeMessageException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And
# importantly, not stack traces everywhere)
raise SynapseError(e.code, e.msg, e.errcode)
defer.returnValue(result)
class ReplicationSendEventRestServlet(RestServlet):
"""Handles events newly created on workers, including persisting and
notifying.
The API looks like:
POST /_synapse/replication/send_event/:event_id
{
"event": { .. serialized event .. },
"internal_metadata": { .. serialized internal_metadata .. },
"rejected_reason": .., // The event.rejected_reason field
"context": { .. serialized event context .. },
"requester": { .. serialized requester .. },
"ratelimit": true,
"extra_users": [],
}
"""
PATTERNS = [re.compile("^/_synapse/replication/send_event/(?P<event_id>[^/]+)$")]
def __init__(self, hs):
super(ReplicationSendEventRestServlet, self).__init__()
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
# The responses are tiny, so we may as well cache them for a while
self.response_cache = ResponseCache(hs, timeout_ms=30 * 60 * 1000)
def on_PUT(self, request, event_id):
result = self.response_cache.get(event_id)
if not result:
result = self.response_cache.set(
event_id,
self._handle_request(request)
)
else:
logger.warn("Returning cached response")
return make_deferred_yieldable(result)
@preserve_fn
@defer.inlineCallbacks
def _handle_request(self, request):
with Measure(self.clock, "repl_send_event_parse"):
content = parse_json_object_from_request(request)
event_dict = content["event"]
internal_metadata = content["internal_metadata"]
rejected_reason = content["rejected_reason"]
event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
requester = Requester.deserialize(self.store, content["requester"])
context = yield EventContext.deserialize(self.store, content["context"])
ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
if requester.user:
request.authenticated_entity = requester.user.to_string()
logger.info(
"Got event to send with ID: %s into room: %s",
event.event_id, event.room_id,
)
yield self.event_creation_handler.persist_and_notify_client_event(
requester, event, context,
ratelimit=ratelimit,
extra_users=extra_users,
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
ReplicationSendEventRestServlet(hs).register(http_server)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,50 +14,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore from synapse.storage.account_data import AccountDataWorkerStore
from synapse.storage.account_data import AccountDataStore from synapse.storage.tags import TagsWorkerStore
from synapse.storage.tags import TagsStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedAccountDataStore(BaseSlavedStore): class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedAccountDataStore, self).__init__(db_conn, hs)
self._account_data_id_gen = SlavedIdTracker( self._account_data_id_gen = SlavedIdTracker(
db_conn, "account_data_max_stream_id", "stream_id", db_conn, "account_data_max_stream_id", "stream_id",
) )
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache",
self._account_data_id_gen.get_current_token(),
)
get_account_data_for_user = ( super(SlavedAccountDataStore, self).__init__(db_conn, hs)
AccountDataStore.__dict__["get_account_data_for_user"]
)
get_global_account_data_by_type_for_users = (
AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
)
get_global_account_data_by_type_for_user = (
AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
)
get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
get_tags_for_room = (
DataStore.get_tags_for_room.__func__
)
get_account_data_for_room = (
DataStore.get_account_data_for_room.__func__
)
get_updated_tags = DataStore.get_updated_tags.__func__
get_updated_account_data_for_user = (
DataStore.get_updated_account_data_for_user.__func__
)
def get_max_account_data_stream_id(self): def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token() return self._account_data_id_gen.get_current_token()
@ -85,6 +56,10 @@ class SlavedAccountDataStore(BaseSlavedStore):
(row.data_type, row.user_id,) (row.data_type, row.user_id,)
) )
self.get_account_data_for_user.invalidate((row.user_id,)) self.get_account_data_for_user.invalidate((row.user_id,))
self.get_account_data_for_room.invalidate((row.user_id, row.room_id,))
self.get_account_data_for_room_and_type.invalidate(
(row.user_id, row.room_id, row.data_type,),
)
self._account_data_stream_cache.entity_has_changed( self._account_data_stream_cache.entity_has_changed(
row.user_id, token row.user_id, token
) )

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,33 +14,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import BaseSlavedStore from synapse.storage.appservice import (
from synapse.storage import DataStore ApplicationServiceWorkerStore, ApplicationServiceTransactionWorkerStore,
from synapse.config.appservice import load_appservices
from synapse.storage.appservice import _make_exclusive_regex
class SlavedApplicationServiceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedApplicationServiceStore, self).__init__(db_conn, hs)
self.services_cache = load_appservices(
hs.config.server_name,
hs.config.app_service_config_files
) )
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__ class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore,
get_app_services = DataStore.get_app_services.__func__ ApplicationServiceWorkerStore):
get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__ pass
create_appservice_txn = DataStore.create_appservice_txn.__func__
get_appservices_by_state = DataStore.get_appservices_by_state.__func__
get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
_get_last_txn = DataStore._get_last_txn.__func__
complete_appservice_txn = DataStore.complete_appservice_txn.__func__
get_appservice_state = DataStore.get_appservice_state.__func__
set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
set_appservice_state = DataStore.set_appservice_state.__func__
get_if_app_services_interested_in_user = (
DataStore.get_if_app_services_interested_in_user.__func__
)

View file

@ -14,10 +14,8 @@
# limitations under the License. # limitations under the License.
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from synapse.storage.directory import DirectoryStore from synapse.storage.directory import DirectoryWorkerStore
class DirectoryStore(BaseSlavedStore): class DirectoryStore(DirectoryWorkerStore, BaseSlavedStore):
get_aliases_for_room = DirectoryStore.__dict__[ pass
"get_aliases_for_room"
]

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,13 +16,13 @@
import logging import logging
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.storage import DataStore from synapse.storage.event_federation import EventFederationWorkerStore
from synapse.storage.event_federation import EventFederationStore from synapse.storage.event_push_actions import EventPushActionsWorkerStore
from synapse.storage.event_push_actions import EventPushActionsStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.storage.roommember import RoomMemberStore from synapse.storage.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateGroupReadStore from synapse.storage.state import StateGroupWorkerStore
from synapse.storage.stream import StreamStore from synapse.storage.stream import StreamWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.storage.signatures import SignatureWorkerStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
@ -37,138 +38,33 @@ logger = logging.getLogger(__name__)
# the method descriptor on the DataStore and chuck them into our class. # the method descriptor on the DataStore and chuck them into our class.
class SlavedEventStore(StateGroupReadStore, BaseSlavedStore): class SlavedEventStore(EventFederationWorkerStore,
RoomMemberWorkerStore,
EventPushActionsWorkerStore,
StreamWorkerStore,
EventsWorkerStore,
StateGroupWorkerStore,
SignatureWorkerStore,
BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedEventStore, self).__init__(db_conn, hs)
self._stream_id_gen = SlavedIdTracker( self._stream_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", db_conn, "events", "stream_ordering",
) )
self._backfill_id_gen = SlavedIdTracker( self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1 db_conn, "events", "stream_ordering", step=-1
) )
events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
entity_column="room_id",
stream_column="stream_ordering",
max_value=events_max,
)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache", min_event_val,
prefilled_cache=event_cache_prefill,
)
self._membership_stream_cache = StreamChangeCache(
"MembershipStreamChangeCache", events_max,
)
self.stream_ordering_month_ago = 0 super(SlavedEventStore, self).__init__(db_conn, hs)
self._stream_order_on_start = self.get_room_max_stream_ordering()
# Cached functions can't be accessed through a class instance so we need # Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them. # to reach inside the __dict__ to extract them.
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
get_hosts_in_room = RoomMemberStore.__dict__["get_hosts_in_room"]
get_users_who_share_room_with_user = (
RoomMemberStore.__dict__["get_users_who_share_room_with_user"]
)
get_latest_event_ids_in_room = EventFederationStore.__dict__[
"get_latest_event_ids_in_room"
]
get_invited_rooms_for_user = RoomMemberStore.__dict__[
"get_invited_rooms_for_user"
]
get_unread_event_push_actions_by_room_for_user = (
EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
)
_get_unread_counts_by_receipt_txn = (
DataStore._get_unread_counts_by_receipt_txn.__func__
)
_get_unread_counts_by_pos_txn = (
DataStore._get_unread_counts_by_pos_txn.__func__
)
get_recent_event_ids_for_room = (
StreamStore.__dict__["get_recent_event_ids_for_room"]
)
_get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
has_room_changed_since = DataStore.has_room_changed_since.__func__
get_unread_push_actions_for_user_in_range_for_http = ( def get_room_max_stream_ordering(self):
DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__ return self._stream_id_gen.get_current_token()
)
get_unread_push_actions_for_user_in_range_for_email = (
DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
)
get_push_action_users_in_range = (
DataStore.get_push_action_users_in_range.__func__
)
get_event = DataStore.get_event.__func__
get_events = DataStore.get_events.__func__
get_rooms_for_user_where_membership_is = (
DataStore.get_rooms_for_user_where_membership_is.__func__
)
get_membership_changes_for_user = (
DataStore.get_membership_changes_for_user.__func__
)
get_room_events_max_id = DataStore.get_room_events_max_id.__func__
get_room_events_stream_for_room = (
DataStore.get_room_events_stream_for_room.__func__
)
get_events_around = DataStore.get_events_around.__func__
get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
_get_joined_users_from_context = (
RoomMemberStore.__dict__["_get_joined_users_from_context"]
)
get_joined_hosts = DataStore.get_joined_hosts.__func__ def get_room_min_stream_ordering(self):
_get_joined_hosts = RoomMemberStore.__dict__["_get_joined_hosts"] return self._backfill_id_gen.get_current_token()
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
get_room_events_stream_for_rooms = (
DataStore.get_room_events_stream_for_rooms.__func__
)
is_host_joined = RoomMemberStore.__dict__["is_host_joined"]
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
_get_events = DataStore._get_events.__func__
_get_events_from_cache = DataStore._get_events_from_cache.__func__
_invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
_enqueue_events = DataStore._enqueue_events.__func__
_do_fetch = DataStore._do_fetch.__func__
_fetch_event_rows = DataStore._fetch_event_rows.__func__
_get_event_from_row = DataStore._get_event_from_row.__func__
_get_rooms_for_user_where_membership_is_txn = (
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
)
_get_events_around_txn = DataStore._get_events_around_txn.__func__
get_backfill_events = DataStore.get_backfill_events.__func__
_get_backfill_events = DataStore._get_backfill_events.__func__
get_missing_events = DataStore.get_missing_events.__func__
_get_missing_events = DataStore._get_missing_events.__func__
get_auth_chain = DataStore.get_auth_chain.__func__
get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
_get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
get_room_max_stream_ordering = DataStore.get_room_max_stream_ordering.__func__
get_forward_extremeties_for_room = (
DataStore.get_forward_extremeties_for_room.__func__
)
_get_forward_extremeties_for_room = (
EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
)
get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__
get_federation_out_pos = DataStore.get_federation_out_pos.__func__
update_federation_out_pos = DataStore.update_federation_out_pos.__func__
def stream_positions(self): def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions() result = super(SlavedEventStore, self).stream_positions()

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,29 +16,15 @@
from .events import SlavedEventStore from .events import SlavedEventStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore from synapse.storage.push_rule import PushRulesWorkerStore
from synapse.storage.push_rule import PushRuleStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedPushRuleStore(SlavedEventStore): class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedPushRuleStore, self).__init__(db_conn, hs)
self._push_rules_stream_id_gen = SlavedIdTracker( self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id", db_conn, "push_rules_stream", "stream_id",
) )
self.push_rules_stream_cache = StreamChangeCache( super(SlavedPushRuleStore, self).__init__(db_conn, hs)
"PushRulesStreamChangeCache",
self._push_rules_stream_id_gen.get_current_token(),
)
get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
get_push_rules_enabled_for_user = (
PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
)
have_push_rules_changed_for_user = (
DataStore.have_push_rules_changed_for_user.__func__
)
def get_push_rules_stream_token(self): def get_push_rules_stream_token(self):
return ( return (
@ -45,6 +32,9 @@ class SlavedPushRuleStore(SlavedEventStore):
self._stream_id_gen.get_current_token(), self._stream_id_gen.get_current_token(),
) )
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
def stream_positions(self): def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions() result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,10 +17,10 @@
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore from synapse.storage.pusher import PusherWorkerStore
class SlavedPusherStore(BaseSlavedStore): class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedPusherStore, self).__init__(db_conn, hs) super(SlavedPusherStore, self).__init__(db_conn, hs)
@ -28,13 +29,6 @@ class SlavedPusherStore(BaseSlavedStore):
extra_tables=[("deleted_pushers", "stream_id")], extra_tables=[("deleted_pushers", "stream_id")],
) )
get_all_pushers = DataStore.get_all_pushers.__func__
get_pushers_by = DataStore.get_pushers_by.__func__
get_pushers_by_app_id_and_pushkey = (
DataStore.get_pushers_by_app_id_and_pushkey.__func__
)
_decode_pushers_rows = DataStore._decode_pushers_rows.__func__
def stream_positions(self): def stream_positions(self):
result = super(SlavedPusherStore, self).stream_positions() result = super(SlavedPusherStore, self).stream_positions()
result["pushers"] = self._pushers_id_gen.get_current_token() result["pushers"] = self._pushers_id_gen.get_current_token()

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,9 +17,7 @@
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore from synapse.storage.receipts import ReceiptsWorkerStore
from synapse.storage.receipts import ReceiptsStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
# So, um, we want to borrow a load of functions intended for reading from # So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the # a DataStore, but we don't want to take functions that either write to the
@ -29,36 +28,19 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
# the method descriptor on the DataStore and chuck them into our class. # the method descriptor on the DataStore and chuck them into our class.
class SlavedReceiptsStore(BaseSlavedStore): class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedReceiptsStore, self).__init__(db_conn, hs) # We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = SlavedIdTracker( self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id" db_conn, "receipts_linearized", "stream_id"
) )
self._receipts_stream_cache = StreamChangeCache( super(SlavedReceiptsStore, self).__init__(db_conn, hs)
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
)
get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"] def get_max_receipt_stream_id(self):
get_linearized_receipts_for_room = ( return self._receipts_id_gen.get_current_token()
ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
)
_get_linearized_receipts_for_rooms = (
ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
)
get_last_receipt_event_id_for_user = (
ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
)
get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
get_linearized_receipts_for_rooms = (
DataStore.get_linearized_receipts_for_rooms.__func__
)
def stream_positions(self): def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions() result = super(SlavedReceiptsStore, self).stream_positions()
@ -71,6 +53,8 @@ class SlavedReceiptsStore(BaseSlavedStore):
self.get_last_receipt_event_id_for_user.invalidate( self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type) (user_id, room_id, receipt_type)
) )
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "receipts": if stream_name == "receipts":

View file

@ -14,20 +14,8 @@
# limitations under the License. # limitations under the License.
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from synapse.storage import DataStore from synapse.storage.registration import RegistrationWorkerStore
from synapse.storage.registration import RegistrationStore
class SlavedRegistrationStore(BaseSlavedStore): class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs): pass
super(SlavedRegistrationStore, self).__init__(db_conn, hs)
# TODO: use the cached version and invalidate deleted tokens
get_user_by_access_token = RegistrationStore.__dict__[
"get_user_by_access_token"
]
_query_for_auth = DataStore._query_for_auth.__func__
get_user_by_id = RegistrationStore.__dict__[
"get_user_by_id"
]

View file

@ -14,32 +14,19 @@
# limitations under the License. # limitations under the License.
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from synapse.storage import DataStore from synapse.storage.room import RoomWorkerStore
from synapse.storage.room import RoomStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
class RoomStore(BaseSlavedStore): class RoomStore(RoomWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(RoomStore, self).__init__(db_conn, hs) super(RoomStore, self).__init__(db_conn, hs)
self._public_room_id_gen = SlavedIdTracker( self._public_room_id_gen = SlavedIdTracker(
db_conn, "public_room_list_stream", "stream_id" db_conn, "public_room_list_stream", "stream_id"
) )
get_public_room_ids = DataStore.get_public_room_ids.__func__ def get_current_public_room_stream_id(self):
get_current_public_room_stream_id = ( return self._public_room_id_gen.get_current_token()
DataStore.get_current_public_room_stream_id.__func__
)
get_public_room_ids_at_stream_id = (
RoomStore.__dict__["get_public_room_ids_at_stream_id"]
)
get_public_room_ids_at_stream_id_txn = (
DataStore.get_public_room_ids_at_stream_id_txn.__func__
)
get_published_at_stream_id_txn = (
DataStore.get_published_at_stream_id_txn.__func__
)
get_public_room_changes = DataStore.get_public_room_changes.__func__
def stream_positions(self): def stream_positions(self):
result = super(RoomStore, self).stream_positions() result = super(RoomStore, self).stream_positions()

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,7 +17,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError, Codes, NotFoundError
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
@ -113,12 +114,18 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
class PurgeHistoryRestServlet(ClientV1RestServlet): class PurgeHistoryRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_path_patterns(
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" "/admin/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
) )
def __init__(self, hs): def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer)
"""
super(PurgeHistoryRestServlet, self).__init__(hs) super(PurgeHistoryRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_id): def on_POST(self, request, room_id, event_id):
@ -128,9 +135,93 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
if not is_admin: if not is_admin:
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
yield self.handlers.message_handler.purge_history(room_id, event_id) body = parse_json_object_from_request(request, allow_empty_body=True)
defer.returnValue((200, {})) delete_local_events = bool(body.get("delete_local_events", False))
# establish the topological ordering we should keep events from. The
# user can provide an event_id in the URL or the request body, or can
# provide a timestamp in the request body.
if event_id is None:
event_id = body.get('purge_up_to_event_id')
if event_id is not None:
event = yield self.store.get_event(event_id)
if event.room_id != room_id:
raise SynapseError(400, "Event is for wrong room.")
depth = event.depth
logger.info(
"[purge] purging up to depth %i (event_id %s)",
depth, event_id,
)
elif 'purge_up_to_ts' in body:
ts = body['purge_up_to_ts']
if not isinstance(ts, int):
raise SynapseError(
400, "purge_up_to_ts must be an int",
errcode=Codes.BAD_JSON,
)
stream_ordering = (
yield self.store.find_first_stream_ordering_after_ts(ts)
)
(_, depth, _) = (
yield self.store.get_room_event_after_stream_ordering(
room_id, stream_ordering,
)
)
logger.info(
"[purge] purging up to depth %i (received_ts %i => "
"stream_ordering %i)",
depth, ts, stream_ordering,
)
else:
raise SynapseError(
400,
"must specify purge_up_to_event_id or purge_up_to_ts",
errcode=Codes.BAD_JSON,
)
purge_id = yield self.handlers.message_handler.start_purge_history(
room_id, depth,
delete_local_events=delete_local_events,
)
defer.returnValue((200, {
"purge_id": purge_id,
}))
class PurgeHistoryStatusRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/admin/purge_history_status/(?P<purge_id>[^/]+)"
)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer)
"""
super(PurgeHistoryStatusRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, purge_id):
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
purge_status = self.handlers.message_handler.get_purge_status(purge_id)
if purge_status is None:
raise NotFoundError("purge id '%s' not found" % purge_id)
defer.returnValue((200, purge_status.asdict()))
class DeactivateAccountRestServlet(ClientV1RestServlet): class DeactivateAccountRestServlet(ClientV1RestServlet):
@ -171,6 +262,8 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id): def on_POST(self, request, room_id):
@ -203,8 +296,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
) )
new_room_id = info["room_id"] new_room_id = info["room_id"]
msg_handler = self.handlers.message_handler yield self.event_creation_handler.create_and_send_nonmember_event(
yield msg_handler.create_and_send_nonmember_event(
room_creator_requester, room_creator_requester,
{ {
"type": "m.room.message", "type": "m.room.message",
@ -230,7 +322,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
logger.info("Kicking %r from %r...", user_id, room_id) logger.info("Kicking %r from %r...", user_id, room_id)
target_requester = create_requester(user_id) target_requester = create_requester(user_id)
yield self.handlers.room_member_handler.update_membership( yield self.room_member_handler.update_membership(
requester=target_requester, requester=target_requester,
target=target_requester.user, target=target_requester.user,
room_id=room_id, room_id=room_id,
@ -239,9 +331,9 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
ratelimit=False ratelimit=False
) )
yield self.handlers.room_member_handler.forget(target_requester.user, room_id) yield self.room_member_handler.forget(target_requester.user, room_id)
yield self.handlers.room_member_handler.update_membership( yield self.room_member_handler.update_membership(
requester=target_requester, requester=target_requester,
target=target_requester.user, target=target_requester.user,
room_id=new_room_id, room_id=new_room_id,
@ -289,6 +381,27 @@ class QuarantineMediaInRoom(ClientV1RestServlet):
defer.returnValue((200, {"num_quarantined": num_quarantined})) defer.returnValue((200, {"num_quarantined": num_quarantined}))
class ListMediaInRoom(ClientV1RestServlet):
"""Lists all of the media in a given room.
"""
PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
def __init__(self, hs):
super(ListMediaInRoom, self).__init__(hs)
self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
class ResetPasswordRestServlet(ClientV1RestServlet): class ResetPasswordRestServlet(ClientV1RestServlet):
"""Post request to allow an administrator reset password for a user. """Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse. This needs user to have administrator access in Synapse.
@ -479,6 +592,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server) WhoisRestServlet(hs).register(http_server)
PurgeMediaCacheRestServlet(hs).register(http_server) PurgeMediaCacheRestServlet(hs).register(http_server)
PurgeHistoryStatusRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server)
UsersRestServlet(hs).register(http_server) UsersRestServlet(hs).register(http_server)
@ -487,3 +601,4 @@ def register_servlets(hs, http_server):
SearchUsersRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server)
ShutdownRoomRestServlet(hs).register(http_server) ShutdownRoomRestServlet(hs).register(http_server)
QuarantineMediaInRoom(hs).register(http_server) QuarantineMediaInRoom(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -82,6 +83,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs) super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.event_creation_hander = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server): def register(self, http_server):
# /room/$roomid/state/$eventtype # /room/$roomid/state/$eventtype
@ -154,7 +157,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
if event_type == EventTypes.Member: if event_type == EventTypes.Member:
membership = content.get("membership", None) membership = content.get("membership", None)
event = yield self.handlers.room_member_handler.update_membership( event = yield self.room_member_handler.update_membership(
requester, requester,
target=UserID.from_string(state_key), target=UserID.from_string(state_key),
room_id=room_id, room_id=room_id,
@ -162,15 +165,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
content=content, content=content,
) )
else: else:
msg_handler = self.handlers.message_handler event, context = yield self.event_creation_hander.create_event(
event, context = yield msg_handler.create_event(
requester, requester,
event_dict, event_dict,
token_id=requester.access_token_id, token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
) )
yield msg_handler.send_nonmember_event(requester, event, context) yield self.event_creation_hander.send_nonmember_event(
requester, event, context,
)
ret = {} ret = {}
if event: if event:
@ -183,7 +187,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs) super(RoomSendEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.event_creation_hander = hs.get_event_creation_handler()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id] # /rooms/$roomid/send/$event_type[/$txn_id]
@ -205,8 +209,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
if 'ts' in request.args and requester.app_service: if 'ts' in request.args and requester.app_service:
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0) event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
msg_handler = self.handlers.message_handler event = yield self.event_creation_hander.create_and_send_nonmember_event(
event = yield msg_handler.create_and_send_nonmember_event(
requester, requester,
event_dict, event_dict,
txn_id=txn_id, txn_id=txn_id,
@ -227,7 +230,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
class JoinRoomAliasServlet(ClientV1RestServlet): class JoinRoomAliasServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs) super(JoinRoomAliasServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server): def register(self, http_server):
# /join/$room_identifier[/$txn_id] # /join/$room_identifier[/$txn_id]
@ -255,7 +258,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
except Exception: except Exception:
remote_room_hosts = None remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier): elif RoomAlias.is_valid(room_identifier):
handler = self.handlers.room_member_handler handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier) room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
room_id = room_id.to_string() room_id = room_id.to_string()
@ -264,7 +267,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
room_identifier, room_identifier,
)) ))
yield self.handlers.room_member_handler.update_membership( yield self.room_member_handler.update_membership(
requester=requester, requester=requester,
target=requester.user, target=requester.user,
room_id=room_id, room_id=room_id,
@ -560,7 +563,7 @@ class RoomEventContextServlet(ClientV1RestServlet):
class RoomForgetRestServlet(ClientV1RestServlet): class RoomForgetRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs) super(RoomForgetRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server): def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
@ -573,7 +576,7 @@ class RoomForgetRestServlet(ClientV1RestServlet):
allow_guest=False, allow_guest=False,
) )
yield self.handlers.room_member_handler.forget( yield self.room_member_handler.forget(
user=requester.user, user=requester.user,
room_id=room_id, room_id=room_id,
) )
@ -591,12 +594,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs) super(RoomMembershipRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/[invite|join|leave] # /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/" PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
"(?P<membership_action>join|invite|leave|ban|unban|kick|forget)") "(?P<membership_action>join|invite|leave|ban|unban|kick)")
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -620,7 +623,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
content = {} content = {}
if membership_action == "invite" and self._has_3pid_invite_keys(content): if membership_action == "invite" and self._has_3pid_invite_keys(content):
yield self.handlers.room_member_handler.do_3pid_invite( yield self.room_member_handler.do_3pid_invite(
room_id, room_id,
requester.user, requester.user,
content["medium"], content["medium"],
@ -642,7 +645,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
if 'reason' in content and membership_action in ['kick', 'ban']: if 'reason' in content and membership_action in ['kick', 'ban']:
event_content = {'reason': content['reason']} event_content = {'reason': content['reason']}
yield self.handlers.room_member_handler.update_membership( yield self.room_member_handler.update_membership(
requester=requester, requester=requester,
target=target, target=target,
room_id=room_id, room_id=room_id,
@ -670,6 +673,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs) super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
def register(self, http_server): def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@ -680,8 +684,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
msg_handler = self.handlers.message_handler event = yield self.event_creation_handler.create_and_send_nonmember_event(
event = yield msg_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Redaction, "type": EventTypes.Redaction,

View file

@ -183,7 +183,7 @@ class RegisterRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.room_member_handler = hs.get_handlers().room_member_handler self.room_member_handler = hs.get_room_member_handler()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()

View file

@ -472,8 +472,10 @@ class MediaRepository(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def generate_local_exact_thumbnail(self, media_id, t_width, t_height, def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
t_method, t_type): t_method, t_type, url_cache):
input_path = self.filepaths.local_media_filepath(media_id) input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
None, media_id, url_cache=url_cache,
))
thumbnailer = Thumbnailer(input_path) thumbnailer = Thumbnailer(input_path)
t_byte_source = yield make_deferred_yieldable(threads.deferToThread( t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@ -486,6 +488,7 @@ class MediaRepository(object):
file_info = FileInfo( file_info = FileInfo(
server_name=None, server_name=None,
file_id=media_id, file_id=media_id,
url_cache=url_cache,
thumbnail=True, thumbnail=True,
thumbnail_width=t_width, thumbnail_width=t_width,
thumbnail_height=t_height, thumbnail_height=t_height,
@ -512,7 +515,9 @@ class MediaRepository(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id, def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
t_width, t_height, t_method, t_type): t_width, t_height, t_method, t_type):
input_path = self.filepaths.remote_media_filepath(server_name, file_id) input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
server_name, file_id, url_cache=False,
))
thumbnailer = Thumbnailer(input_path) thumbnailer = Thumbnailer(input_path)
t_byte_source = yield make_deferred_yieldable(threads.deferToThread( t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@ -570,12 +575,9 @@ class MediaRepository(object):
if not requirements: if not requirements:
return return
if server_name: input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
input_path = self.filepaths.remote_media_filepath(server_name, file_id) server_name, file_id, url_cache=url_cache,
elif url_cache: ))
input_path = self.filepaths.url_cache_filepath(media_id)
else:
input_path = self.filepaths.local_media_filepath(media_id)
thumbnailer = Thumbnailer(input_path) thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width m_width = thumbnailer.width

View file

@ -18,6 +18,7 @@ from twisted.protocols.basic import FileSender
from ._base import Responder from ._base import Responder
from synapse.util.file_consumer import BackgroundFileConsumer
from synapse.util.logcontext import make_deferred_yieldable from synapse.util.logcontext import make_deferred_yieldable
import contextlib import contextlib
@ -26,6 +27,7 @@ import logging
import shutil import shutil
import sys import sys
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,17 +58,13 @@ class MediaStorage(object):
Returns: Returns:
Deferred[str]: the file path written to in the primary media store Deferred[str]: the file path written to in the primary media store
""" """
path = self._file_info_to_path(file_info)
fname = os.path.join(self.local_media_directory, path)
dirname = os.path.dirname(fname)
if not os.path.exists(dirname):
os.makedirs(dirname)
with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository # Write to the main repository
yield make_deferred_yieldable(threads.deferToThread( yield make_deferred_yieldable(threads.deferToThread(
_write_file_synchronously, source, fname, _write_file_synchronously, source, f,
)) ))
yield finish_cb()
defer.returnValue(fname) defer.returnValue(fname)
@ -151,6 +149,37 @@ class MediaStorage(object):
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks
def ensure_media_is_in_local_cache(self, file_info):
"""Ensures that the given file is in the local cache. Attempts to
download it from storage providers if it isn't.
Args:
file_info (FileInfo)
Returns:
Deferred[str]: Full path to local file
"""
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path):
defer.returnValue(local_path)
dirname = os.path.dirname(local_path)
if not os.path.exists(dirname):
os.makedirs(dirname)
for provider in self.storage_providers:
res = yield provider.fetch(path, file_info)
if res:
with res:
consumer = BackgroundFileConsumer(open(local_path, "w"))
yield res.write_to_consumer(consumer)
yield consumer.wait()
defer.returnValue(local_path)
raise Exception("file could not be found")
def _file_info_to_path(self, file_info): def _file_info_to_path(self, file_info):
"""Converts file_info into a relative path. """Converts file_info into a relative path.
@ -201,21 +230,16 @@ class MediaStorage(object):
) )
def _write_file_synchronously(source, fname): def _write_file_synchronously(source, dest):
"""Write `source` to the path `fname` synchronously. Should be called """Write `source` to the file like `dest` synchronously. Should be called
from a thread. from a thread.
Args: Args:
source: A file like object to be written source: A file like object that's to be written
fname (str): Path to write to dest: A file like object to be written to
""" """
dirname = os.path.dirname(fname)
if not os.path.exists(dirname):
os.makedirs(dirname)
source.seek(0) # Ensure we read from the start of the file source.seek(0) # Ensure we read from the start of the file
with open(fname, "wb") as f: shutil.copyfileobj(source, dest)
shutil.copyfileobj(source, f)
class FileResponder(Responder): class FileResponder(Responder):
@ -228,9 +252,8 @@ class FileResponder(Responder):
def __init__(self, open_file): def __init__(self, open_file):
self.open_file = open_file self.open_file = open_file
@defer.inlineCallbacks
def write_to_consumer(self, consumer): def write_to_consumer(self, consumer):
yield FileSender().beginFileTransfer(self.open_file, consumer) return FileSender().beginFileTransfer(self.open_file, consumer)
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
self.open_file.close() self.open_file.close()

View file

@ -12,6 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import cgi
import datetime
import errno
import fnmatch
import itertools
import logging
import os
import re
import shutil
import sys
import traceback
import ujson as json
import urlparse
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer from twisted.internet import defer
@ -33,18 +46,6 @@ from synapse.http.server import (
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii from synapse.util.stringutils import is_ascii
import os
import re
import fnmatch
import cgi
import ujson as json
import urlparse
import itertools
import datetime
import errno
import shutil
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -286,17 +287,28 @@ class PreviewUrlResource(Resource):
url_cache=True, url_cache=True,
) )
try:
with self.media_storage.store_into_file(file_info) as (f, fname, finish): with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
logger.debug("Trying to get url '%s'" % url) logger.debug("Trying to get url '%s'" % url)
length, headers, uri, code = yield self.client.get_file( length, headers, uri, code = yield self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size, url, output_stream=f, max_size=self.max_spider_size,
) )
except Exception as e:
# FIXME: pass through 404s and other error messages nicely # FIXME: pass through 404s and other error messages nicely
logger.warn("Error downloading %s: %r", url, e)
raise SynapseError(
500, "Failed to download content: %s" % (
traceback.format_exception_only(sys.exc_type, e),
),
Codes.UNKNOWN,
)
yield finish() yield finish()
try:
if "Content-Type" in headers:
media_type = headers["Content-Type"][0] media_type = headers["Content-Type"][0]
else:
media_type = "application/octet-stream"
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None) content_disposition = headers.get("Content-Disposition", None)
@ -336,10 +348,11 @@ class PreviewUrlResource(Resource):
) )
except Exception as e: except Exception as e:
raise SynapseError( logger.error("Error handling downloaded %s: %r", url, e)
500, ("Failed to download content: %s" % e), # TODO: we really ought to delete the downloaded file in this
Codes.UNKNOWN # case, since we won't have recorded it in the db, and will
) # therefore not expire it.
raise
defer.returnValue({ defer.returnValue({
"media_type": media_type, "media_type": media_type,

View file

@ -164,7 +164,8 @@ class ThumbnailResource(Resource):
# Okay, so we generate one. # Okay, so we generate one.
file_path = yield self.media_repo.generate_local_exact_thumbnail( file_path = yield self.media_repo.generate_local_exact_thumbnail(
media_id, desired_width, desired_height, desired_method, desired_type media_id, desired_width, desired_height, desired_method, desired_type,
url_cache=media_info["url_cache"],
) )
if file_path: if file_path:

View file

@ -32,8 +32,10 @@ from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.crypto.keyring import Keyring from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.events.spamcheck import SpamChecker from synapse.events.spamcheck import SpamChecker
from synapse.federation import initialize_http_replication from synapse.federation.federation_client import FederationClient
from synapse.federation.federation_server import FederationServer
from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.federation_server import FederationHandlerRegistry
from synapse.federation.transport.client import TransportLayerClient from synapse.federation.transport.client import TransportLayerClient
from synapse.federation.transaction_queue import TransactionQueue from synapse.federation.transaction_queue import TransactionQueue
from synapse.handlers import Handlers from synapse.handlers import Handlers
@ -45,6 +47,7 @@ from synapse.handlers.device import DeviceHandler
from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.presence import PresenceHandler from synapse.handlers.presence import PresenceHandler
from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.room_member import RoomMemberHandler
from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.sync import SyncHandler from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler from synapse.handlers.typing import TypingHandler
@ -55,6 +58,7 @@ from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.user_directory import UserDirectoryHandler from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.handlers.groups_local import GroupsLocalHandler from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.handlers.profile import ProfileHandler from synapse.handlers.profile import ProfileHandler
from synapse.handlers.message import EventCreationHandler
from synapse.groups.groups_server import GroupsServerHandler from synapse.groups.groups_server import GroupsServerHandler
from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
@ -66,7 +70,7 @@ from synapse.rest.media.v1.media_repository import (
MediaRepository, MediaRepository,
MediaRepositoryResource, MediaRepositoryResource,
) )
from synapse.state import StateHandler from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.streams.events import EventSources from synapse.streams.events import EventSources
from synapse.util import Clock from synapse.util import Clock
@ -97,11 +101,13 @@ class HomeServer(object):
DEPENDENCIES = [ DEPENDENCIES = [
'http_client', 'http_client',
'db_pool', 'db_pool',
'replication_layer', 'federation_client',
'federation_server',
'handlers', 'handlers',
'v1auth', 'v1auth',
'auth', 'auth',
'state_handler', 'state_handler',
'state_resolution_handler',
'presence_handler', 'presence_handler',
'sync_handler', 'sync_handler',
'typing_handler', 'typing_handler',
@ -117,6 +123,7 @@ class HomeServer(object):
'application_service_handler', 'application_service_handler',
'device_message_handler', 'device_message_handler',
'profile_handler', 'profile_handler',
'event_creation_handler',
'deactivate_account_handler', 'deactivate_account_handler',
'set_password_handler', 'set_password_handler',
'notifier', 'notifier',
@ -142,6 +149,8 @@ class HomeServer(object):
'groups_attestation_signing', 'groups_attestation_signing',
'groups_attestation_renewer', 'groups_attestation_renewer',
'spam_checker', 'spam_checker',
'room_member_handler',
'federation_registry',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -190,8 +199,11 @@ class HomeServer(object):
def get_ratelimiter(self): def get_ratelimiter(self):
return self.ratelimiter return self.ratelimiter
def build_replication_layer(self): def build_federation_client(self):
return initialize_http_replication(self) return FederationClient(self)
def build_federation_server(self):
return FederationServer(self)
def build_handlers(self): def build_handlers(self):
return Handlers(self) return Handlers(self)
@ -224,6 +236,9 @@ class HomeServer(object):
def build_state_handler(self): def build_state_handler(self):
return StateHandler(self) return StateHandler(self)
def build_state_resolution_handler(self):
return StateResolutionHandler(self)
def build_presence_handler(self): def build_presence_handler(self):
return PresenceHandler(self) return PresenceHandler(self)
@ -272,6 +287,9 @@ class HomeServer(object):
def build_profile_handler(self): def build_profile_handler(self):
return ProfileHandler(self) return ProfileHandler(self)
def build_event_creation_handler(self):
return EventCreationHandler(self)
def build_deactivate_account_handler(self): def build_deactivate_account_handler(self):
return DeactivateAccountHandler(self) return DeactivateAccountHandler(self)
@ -307,6 +325,23 @@ class HomeServer(object):
**self.db_config.get("args", {}) **self.db_config.get("args", {})
) )
def get_db_conn(self, run_new_connection=True):
"""Makes a new connection to the database, skipping the db pool
Returns:
Connection: a connection object implementing the PEP-249 spec
"""
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def build_media_repository_resource(self): def build_media_repository_resource(self):
# build the media repo resource. This indirects through the HomeServer # build the media repo resource. This indirects through the HomeServer
# to ensure that we only have a single instance of # to ensure that we only have a single instance of
@ -356,6 +391,12 @@ class HomeServer(object):
def build_spam_checker(self): def build_spam_checker(self):
return SpamChecker(self) return SpamChecker(self)
def build_room_member_handler(self):
return RoomMemberHandler(self)
def build_federation_registry(self):
return FederationHandlerRegistry()
def remove_pusher(self, app_id, push_key, user_id): def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View file

@ -34,6 +34,9 @@ class HomeServer(object):
def get_state_handler(self) -> synapse.state.StateHandler: def get_state_handler(self) -> synapse.state.StateHandler:
pass pass
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
pass
def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
pass pass

View file

@ -58,7 +58,11 @@ class _StateCacheEntry(object):
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
def __init__(self, state, state_group, prev_group=None, delta_ids=None): def __init__(self, state, state_group, prev_group=None, delta_ids=None):
# dict[(str, str), str] map from (type, state_key) to event_id
self.state = frozendict(state) self.state = frozendict(state)
# the ID of a state group if one and only one is involved.
# otherwise, None otherwise?
self.state_group = state_group self.state_group = state_group
self.prev_group = prev_group self.prev_group = prev_group
@ -81,31 +85,19 @@ class _StateCacheEntry(object):
class StateHandler(object): class StateHandler(object):
""" Responsible for doing state conflict resolution. """Fetches bits of state from the stores, and does state resolution
where necessary
""" """
def __init__(self, hs): def __init__(self, hs):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
# dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
def start_caching(self): def start_caching(self):
logger.debug("start_caching") # TODO: remove this shim
self._state_resolution_handler.start_caching()
self._state_cache = ExpiringCache(
cache_name="state_cache",
clock=self.clock,
max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
)
self._state_cache.start()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key="", def get_current_state(self, room_id, event_type=None, state_key="",
@ -127,7 +119,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state") logger.debug("calling resolve_state_groups from get_current_state")
ret = yield self.resolve_state_groups(room_id, latest_event_ids) ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state state = ret.state
if event_type: if event_type:
@ -146,19 +138,27 @@ class StateHandler(object):
defer.returnValue(state) defer.returnValue(state)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_state_ids(self, room_id, event_type=None, state_key="", def get_current_state_ids(self, room_id, latest_event_ids=None):
latest_event_ids=None): """Get the current state, or the state at a set of events, for a room
Args:
room_id (str):
latest_event_ids (iterable[str]|None): if given, the forward
extremities to resolve. If None, we look them up from the
database (via a cache)
Returns:
Deferred[dict[(str, str), str)]]: the state dict, mapping from
(event_type, state_key) -> event_id
"""
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state_ids") logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = yield self.resolve_state_groups(room_id, latest_event_ids) ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state state = ret.state
if event_type:
defer.returnValue(state.get((event_type, state_key)))
return
defer.returnValue(state) defer.returnValue(state)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -166,7 +166,7 @@ class StateHandler(object):
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_user_in_room") logger.debug("calling resolve_state_groups from get_current_user_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids) entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(room_id, entry) joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
defer.returnValue(joined_users) defer.returnValue(joined_users)
@ -175,7 +175,7 @@ class StateHandler(object):
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_hosts_in_room") logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids) entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_hosts = yield self.store.get_joined_hosts(room_id, entry) joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
defer.returnValue(joined_hosts) defer.returnValue(joined_hosts)
@ -183,8 +183,15 @@ class StateHandler(object):
def compute_event_context(self, event, old_state=None): def compute_event_context(self, event, old_state=None):
"""Build an EventContext structure for the event. """Build an EventContext structure for the event.
This works out what the current state should be for the event, and
generates a new state group if necessary.
Args: Args:
event (synapse.events.EventBase): event (synapse.events.EventBase):
old_state (dict|None): The state at the event if it can't be
calculated from existing events. This is normally only specified
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
Returns: Returns:
synapse.events.snapshot.EventContext: synapse.events.snapshot.EventContext:
""" """
@ -208,15 +215,22 @@ class StateHandler(object):
context.current_state_ids = {} context.current_state_ids = {}
context.prev_state_ids = {} context.prev_state_ids = {}
context.prev_state_events = [] context.prev_state_events = []
context.state_group = self.store.get_next_state_group()
# We don't store state for outliers, so we don't generate a state
# froup for it.
context.state_group = None
defer.returnValue(context) defer.returnValue(context)
if old_state: if old_state:
# We already have the state, so we don't need to calculate it.
# Let's just correctly fill out the context and create a
# new state group for it.
context = EventContext() context = EventContext()
context.prev_state_ids = { context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
context.state_group = self.store.get_next_state_group()
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
@ -229,11 +243,19 @@ class StateHandler(object):
else: else:
context.current_state_ids = context.prev_state_ids context.current_state_ids = context.prev_state_ids
context.state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=None,
delta_ids=None,
current_state_ids=context.current_state_ids,
)
context.prev_state_events = [] context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context") logger.debug("calling resolve_state_groups from compute_event_context")
entry = yield self.resolve_state_groups( entry = yield self.resolve_state_groups_for_events(
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
) )
@ -242,7 +264,8 @@ class StateHandler(object):
context = EventContext() context = EventContext()
context.prev_state_ids = curr_state context.prev_state_ids = curr_state
if event.is_state(): if event.is_state():
context.state_group = self.store.get_next_state_group() # If this is a state event then we need to create a new state
# group for the state after this event.
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.prev_state_ids: if key in context.prev_state_ids:
@ -253,38 +276,57 @@ class StateHandler(object):
context.current_state_ids[key] = event.event_id context.current_state_ids[key] = event.event_id
if entry.state_group: if entry.state_group:
# If the state at the event has a state group assigned then
# we can use that as the prev group
context.prev_group = entry.state_group context.prev_group = entry.state_group
context.delta_ids = { context.delta_ids = {
key: event.event_id key: event.event_id
} }
elif entry.prev_group: elif entry.prev_group:
# If the state at the event only has a prev group, then we can
# use that as a prev group too.
context.prev_group = entry.prev_group context.prev_group = entry.prev_group
context.delta_ids = dict(entry.delta_ids) context.delta_ids = dict(entry.delta_ids)
context.delta_ids[key] = event.event_id context.delta_ids[key] = event.event_id
else:
if entry.state_group is None:
entry.state_group = self.store.get_next_state_group()
entry.state_id = entry.state_group
context.state_group = entry.state_group context.state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=context.prev_group,
delta_ids=context.delta_ids,
current_state_ids=context.current_state_ids,
)
else:
context.current_state_ids = context.prev_state_ids context.current_state_ids = context.prev_state_ids
context.prev_group = entry.prev_group context.prev_group = entry.prev_group
context.delta_ids = entry.delta_ids context.delta_ids = entry.delta_ids
if entry.state_group is None:
entry.state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=entry.prev_group,
delta_ids=entry.delta_ids,
current_state_ids=context.current_state_ids,
)
entry.state_id = entry.state_group
context.state_group = entry.state_group
context.prev_state_events = [] context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function def resolve_state_groups_for_events(self, room_id, event_ids):
def resolve_state_groups(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each """ Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
Args:
room_id (str):
event_ids (list[str]):
Returns: Returns:
a Deferred tuple of (`state_group`, `state`, `prev_state`). Deferred[_StateCacheEntry]: resolved state
`state_group` is the name of a state group if one and only one is
involved. `state` is a map from (type, state_key) to event, and
`prev_state` is a list of event ids.
""" """
logger.debug("resolve_state_groups event_ids %s", event_ids) logger.debug("resolve_state_groups event_ids %s", event_ids)
@ -295,13 +337,7 @@ class StateHandler(object):
room_id, event_ids room_id, event_ids
) )
logger.debug( if len(state_groups_ids) == 1:
"resolve_state_groups state_groups %s",
state_groups_ids.keys()
)
group_names = frozenset(state_groups_ids.keys())
if len(group_names) == 1:
name, state_list = state_groups_ids.items().pop() name, state_list = state_groups_ids.items().pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name) prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@ -313,6 +349,102 @@ class StateHandler(object):
delta_ids=delta_ids, delta_ids=delta_ids,
)) ))
result = yield self._state_resolution_handler.resolve_state_groups(
room_id, state_groups_ids, None, self._state_map_factory,
)
defer.returnValue(result)
def _state_map_factory(self, ev_ids):
return self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
def resolve_events(self, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
state_set_ids = [{
(ev.type, ev.state_key): ev.event_id
for ev in st
} for st in state_sets]
state_map = {
ev.event_id: ev
for st in state_sets
for ev in st
}
with Measure(self.clock, "state._resolve_events"):
new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
}
return new_state
class StateResolutionHandler(object):
"""Responsible for doing state conflict resolution.
Note that the storage layer depends on this handler, so all functions must
be storage-independent.
"""
def __init__(self, hs):
self.clock = hs.get_clock()
# dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
def start_caching(self):
logger.debug("start_caching")
self._state_cache = ExpiringCache(
cache_name="state_cache",
clock=self.clock,
max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
)
self._state_cache.start()
@defer.inlineCallbacks
@log_function
def resolve_state_groups(
self, room_id, state_groups_ids, event_map, state_map_factory,
):
"""Resolves conflicts between a set of state groups
Always generates a new state group (unless we hit the cache), so should
not be called for a single state group
Args:
room_id (str): room we are resolving for (used for logging)
state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group
(where 'state' is a map from state key to event id)
event_map(dict[str,FrozenEvent]|None):
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing
events will be requested via state_map_factory.
If None, all events will be fetched via state_map_factory.
Returns:
Deferred[_StateCacheEntry]: resolved state
"""
logger.debug(
"resolve_state_groups state_groups %s",
state_groups_ids.keys()
)
group_names = frozenset(state_groups_ids.keys())
with (yield self.resolve_linearizer.queue(group_names)): with (yield self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None: if self._state_cache is not None:
cache = self._state_cache.get(group_names, None) cache = self._state_cache.get(group_names, None)
@ -343,15 +475,18 @@ class StateHandler(object):
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory( new_state = yield resolve_events_with_factory(
state_groups_ids.values(), state_groups_ids.values(),
state_map_factory=lambda ev_ids: self.store.get_events( event_map=event_map,
ev_ids, get_prev_content=False, check_redacted=False, state_map_factory=state_map_factory,
),
) )
else: else:
new_state = { new_state = {
key: e_ids.pop() for key, e_ids in state.items() key: e_ids.pop() for key, e_ids in state.items()
} }
# if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id
# which will be used as a cache key for future resolutions, but
# not get persisted.
state_group = None state_group = None
new_state_event_ids = frozenset(new_state.values()) new_state_event_ids = frozenset(new_state.values())
for sg, events in state_groups_ids.items(): for sg, events in state_groups_ids.items():
@ -388,30 +523,6 @@ class StateHandler(object):
defer.returnValue(cache) defer.returnValue(cache)
def resolve_events(self, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
state_set_ids = [{
(ev.type, ev.state_key): ev.event_id
for ev in st
} for st in state_sets]
state_map = {
ev.event_id: ev
for st in state_sets
for ev in st
}
with Measure(self.clock, "state._resolve_events"):
new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
}
return new_state
def _ordered_events(events): def _ordered_events(events):
def key_func(e): def key_func(e):
@ -429,8 +540,8 @@ def resolve_events_with_state_map(state_sets, state_map):
state_sets. state_sets.
Returns Returns
dict[(str, str), synapse.events.FrozenEvent]: dict[(str, str), str]:
a map from (type, state_key) to event. a map from (type, state_key) to event_id.
""" """
if len(state_sets) == 1: if len(state_sets) == 1:
return state_sets[0] return state_sets[0]
@ -452,6 +563,21 @@ def _seperate(state_sets):
"""Takes the state_sets and figures out which keys are conflicted and """Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated which aren't. i.e., which have multiple different event_ids associated
with them in different state sets. with them in different state sets.
Args:
state_sets(list[dict[(str, str), str]]):
List of dicts of (type, state_key) -> event_id, which are the
different state groups to resolve.
Returns:
(dict[(str, str), str], dict[(str, str), set[str]]):
A tuple of (unconflicted_state, conflicted_state), where:
unconflicted_state is a dict mapping (type, state_key)->event_id
for unconflicted state keys.
conflicted_state is a dict mapping (type, state_key) to a set of
event ids for conflicted state keys.
""" """
unconflicted_state = dict(state_sets[0]) unconflicted_state = dict(state_sets[0])
conflicted_state = {} conflicted_state = {}
@ -482,18 +608,27 @@ def _seperate(state_sets):
@defer.inlineCallbacks @defer.inlineCallbacks
def resolve_events_with_factory(state_sets, state_map_factory): def resolve_events_with_factory(state_sets, event_map, state_map_factory):
""" """
Args: Args:
state_sets(list): List of dicts of (type, state_key) -> event_id, state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve. which are the different state groups to resolve.
event_map(dict[str,FrozenEvent]|None):
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing
events will be requested via state_map_factory.
If None, all events will be fetched via state_map_factory.
state_map_factory(func): will be called state_map_factory(func): will be called
with a list of event_ids that are needed, and should return with with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event. a Deferred of dict of event_id to event.
Returns Returns
Deferred[dict[(str, str), synapse.events.FrozenEvent]]: Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event. a map from (type, state_key) to event_id.
""" """
if len(state_sets) == 1: if len(state_sets) == 1:
defer.returnValue(state_sets[0]) defer.returnValue(state_sets[0])
@ -507,12 +642,16 @@ def resolve_events_with_factory(state_sets, state_map_factory):
for event_ids in conflicted_state.itervalues() for event_ids in conflicted_state.itervalues()
for event_id in event_ids for event_id in event_ids
) )
if event_map is not None:
needed_events -= set(event_map.iterkeys())
logger.info("Asking for %d conflicted events", len(needed_events)) logger.info("Asking for %d conflicted events", len(needed_events))
# dict[str, FrozenEvent]: a map from state event id to event. Only includes # dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict. # the state events which are in conflict (and those in event_map)
state_map = yield state_map_factory(needed_events) state_map = yield state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)
# get the ids of the auth events which allow us to authenticate the # get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state. # conflicted state, picking only from the unconflicting state.
@ -524,6 +663,8 @@ def resolve_events_with_factory(state_sets, state_map_factory):
new_needed_events = set(auth_events.itervalues()) new_needed_events = set(auth_events.itervalues())
new_needed_events -= needed_events new_needed_events -= needed_events
if event_map is not None:
new_needed_events -= set(event_map.iterkeys())
logger.info("Asking for %d auth events", len(new_needed_events)) logger.info("Asking for %d auth events", len(new_needed_events))

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -19,7 +20,6 @@ from synapse.storage.devices import DeviceStore
from .appservice import ( from .appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore ApplicationServiceStore, ApplicationServiceTransactionStore
) )
from ._base import LoggingTransaction
from .directory import DirectoryStore from .directory import DirectoryStore
from .events import EventsStore from .events import EventsStore
from .presence import PresenceStore, UserPresenceState from .presence import PresenceStore, UserPresenceState
@ -104,12 +104,6 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "events", "stream_ordering", step=-1, db_conn, "events", "stream_ordering", step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")] extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
) )
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id"
)
self._presence_id_gen = StreamIdGenerator( self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id" db_conn, "presence_stream", "stream_id"
) )
@ -124,7 +118,6 @@ class DataStore(RoomMemberStore, RoomStore,
) )
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
@ -147,27 +140,6 @@ class DataStore(RoomMemberStore, RoomStore,
else: else:
self._cache_id_gen = None self._cache_id_gen = None
events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
entity_column="room_id",
stream_column="stream_ordering",
max_value=events_max,
)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache", min_event_val,
prefilled_cache=event_cache_prefill,
)
self._membership_stream_cache = StreamChangeCache(
"MembershipStreamChangeCache", events_max,
)
account_max = self._account_data_id_gen.get_current_token()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max,
)
self._presence_on_startup = self._get_active_presence(db_conn) self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict( presence_cache_prefill, min_presence_val = self._get_cache_dict(
@ -181,18 +153,6 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=presence_cache_prefill prefilled_cache=presence_cache_prefill
) )
push_rules_prefill, push_rules_id = self._get_cache_dict(
db_conn, "push_rules_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._push_rules_stream_id_gen.get_current_token()[0],
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache", push_rules_id,
prefilled_cache=push_rules_prefill,
)
max_device_inbox_id = self._device_inbox_id_gen.get_current_token() max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
db_conn, "device_inbox", db_conn, "device_inbox",
@ -227,6 +187,7 @@ class DataStore(RoomMemberStore, RoomStore,
"DeviceListFederationStreamChangeCache", device_list_max, "DeviceListFederationStreamChangeCache", device_list_max,
) )
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
db_conn, "current_state_delta_stream", db_conn, "current_state_delta_stream",
entity_column="room_id", entity_column="room_id",
@ -251,20 +212,6 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=_group_updates_prefill, prefilled_cache=_group_updates_prefill,
) )
cur = LoggingTransaction(
db_conn.cursor(),
name="_find_stream_orderings_for_times_txn",
database_engine=self.database_engine,
after_callbacks=[],
final_callbacks=[],
)
self._find_stream_orderings_for_times_txn(cur)
cur.close()
self.find_stream_orderings_looping_call = self._clock.looping_call(
self._find_stream_orderings_for_times, 10 * 60 * 1000
)
self._stream_order_on_start = self.get_room_max_stream_ordering() self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering()

View file

@ -48,16 +48,16 @@ class LoggingTransaction(object):
passed to the constructor. Adds logging and metrics to the .execute() passed to the constructor. Adds logging and metrics to the .execute()
method.""" method."""
__slots__ = [ __slots__ = [
"txn", "name", "database_engine", "after_callbacks", "final_callbacks", "txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
] ]
def __init__(self, txn, name, database_engine, after_callbacks, def __init__(self, txn, name, database_engine, after_callbacks,
final_callbacks): exception_callbacks):
object.__setattr__(self, "txn", txn) object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name) object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine) object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks) object.__setattr__(self, "after_callbacks", after_callbacks)
object.__setattr__(self, "final_callbacks", final_callbacks) object.__setattr__(self, "exception_callbacks", exception_callbacks)
def call_after(self, callback, *args, **kwargs): def call_after(self, callback, *args, **kwargs):
"""Call the given callback on the main twisted thread after the """Call the given callback on the main twisted thread after the
@ -66,8 +66,8 @@ class LoggingTransaction(object):
""" """
self.after_callbacks.append((callback, args, kwargs)) self.after_callbacks.append((callback, args, kwargs))
def call_finally(self, callback, *args, **kwargs): def call_on_exception(self, callback, *args, **kwargs):
self.final_callbacks.append((callback, args, kwargs)) self.exception_callbacks.append((callback, args, kwargs))
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.txn, name) return getattr(self.txn, name)
@ -215,7 +215,7 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000) self._clock.looping_call(loop, 10000)
def _new_transaction(self, conn, desc, after_callbacks, final_callbacks, def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
logging_context, func, *args, **kwargs): logging_context, func, *args, **kwargs):
start = time.time() * 1000 start = time.time() * 1000
txn_id = self._TXN_ID txn_id = self._TXN_ID
@ -236,7 +236,7 @@ class SQLBaseStore(object):
txn = conn.cursor() txn = conn.cursor()
txn = LoggingTransaction( txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks, txn, name, self.database_engine, after_callbacks,
final_callbacks, exception_callbacks,
) )
r = func(txn, *args, **kwargs) r = func(txn, *args, **kwargs)
conn.commit() conn.commit()
@ -308,11 +308,11 @@ class SQLBaseStore(object):
current_context = LoggingContext.current_context() current_context = LoggingContext.current_context()
after_callbacks = [] after_callbacks = []
final_callbacks = [] exception_callbacks = []
def inner_func(conn, *args, **kwargs): def inner_func(conn, *args, **kwargs):
return self._new_transaction( return self._new_transaction(
conn, desc, after_callbacks, final_callbacks, current_context, conn, desc, after_callbacks, exception_callbacks, current_context,
func, *args, **kwargs func, *args, **kwargs
) )
@ -321,9 +321,10 @@ class SQLBaseStore(object):
for after_callback, after_args, after_kwargs in after_callbacks: for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs) after_callback(*after_args, **after_kwargs)
finally: except: # noqa: E722, as we reraise the exception this is fine.
for after_callback, after_args, after_kwargs in final_callbacks: for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs) after_callback(*after_args, **after_kwargs)
raise
defer.returnValue(result) defer.returnValue(result)
@ -1000,7 +1001,8 @@ class SQLBaseStore(object):
# __exit__ called after the transaction finishes. # __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next() ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__() stream_id = ctx.__enter__()
txn.call_finally(ctx.__exit__, None, None, None) txn.call_on_exception(ctx.__exit__, None, None, None)
txn.call_after(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data) txn.call_after(self.hs.get_notifier().on_new_replication_data)
self._simple_insert_txn( self._simple_insert_txn(

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,18 +14,46 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore
from twisted.internet import defer from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
import abc
import ujson as json import ujson as json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AccountDataStore(SQLBaseStore): class AccountDataWorkerStore(SQLBaseStore):
"""This is an abstract base class where subclasses must implement
`get_max_account_data_stream_id` which can be called in the initializer.
"""
# This ABCMeta metaclass ensures that we cannot be instantiated without
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
def __init__(self, db_conn, hs):
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max,
)
super(AccountDataWorkerStore, self).__init__(db_conn, hs)
@abc.abstractmethod
def get_max_account_data_stream_id(self):
"""Get the current max stream ID for account data stream
Returns:
int
"""
raise NotImplementedError()
@cached() @cached()
def get_account_data_for_user(self, user_id): def get_account_data_for_user(self, user_id):
@ -104,6 +133,7 @@ class AccountDataStore(SQLBaseStore):
for row in rows for row in rows
}) })
@cached(num_args=2)
def get_account_data_for_room(self, user_id, room_id): def get_account_data_for_room(self, user_id, room_id):
"""Get all the client account_data for a user for a room. """Get all the client account_data for a user for a room.
@ -127,6 +157,38 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_room", get_account_data_for_room_txn "get_account_data_for_room", get_account_data_for_room_txn
) )
@cached(num_args=3, max_entries=5000)
def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
"""Get the client account_data of given type for a user for a room.
Args:
user_id(str): The user to get the account_data for.
room_id(str): The room to get the account_data for.
account_data_type (str): The account data type to get.
Returns:
A deferred of the room account_data for that type, or None if
there isn't any set.
"""
def get_account_data_for_room_and_type_txn(txn):
content_json = self._simple_select_one_onecol_txn(
txn,
table="room_account_data",
keyvalues={
"user_id": user_id,
"room_id": room_id,
"account_data_type": account_data_type,
},
retcol="content",
allow_none=True
)
return json.loads(content_json) if content_json else None
return self.runInteraction(
"get_account_data_for_room_and_type",
get_account_data_for_room_and_type_txn,
)
def get_all_updated_account_data(self, last_global_id, last_room_id, def get_all_updated_account_data(self, last_global_id, last_room_id,
current_id, limit): current_id, limit):
"""Get all the client account_data that has changed on the server """Get all the client account_data that has changed on the server
@ -209,6 +271,36 @@ class AccountDataStore(SQLBaseStore):
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
) )
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
ignored_account_data = yield self.get_global_account_data_by_type_for_user(
"m.ignored_user_list", ignorer_user_id,
on_invalidate=cache_context.invalidate,
)
if not ignored_account_data:
defer.returnValue(False)
defer.returnValue(
ignored_user_id in ignored_account_data.get("ignored_users", {})
)
class AccountDataStore(AccountDataWorkerStore):
def __init__(self, db_conn, hs):
self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id"
)
super(AccountDataStore, self).__init__(db_conn, hs)
def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream
Returns:
A deferred int.
"""
return self._account_data_id_gen.get_current_token()
@defer.inlineCallbacks @defer.inlineCallbacks
def add_account_data_to_room(self, user_id, room_id, account_data_type, content): def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
"""Add some account_data to a room for a user. """Add some account_data to a room for a user.
@ -251,6 +343,10 @@ class AccountDataStore(SQLBaseStore):
self._account_data_stream_cache.entity_has_changed(user_id, next_id) self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,)) self.get_account_data_for_user.invalidate((user_id,))
self.get_account_data_for_room.invalidate((user_id, room_id,))
self.get_account_data_for_room_and_type.prefill(
(user_id, room_id, account_data_type,), content,
)
result = self._account_data_id_gen.get_current_token() result = self._account_data_id_gen.get_current_token()
defer.returnValue(result) defer.returnValue(result)
@ -321,16 +417,3 @@ class AccountDataStore(SQLBaseStore):
"update_account_data_max_stream_id", "update_account_data_max_stream_id",
_update, _update,
) )
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
ignored_account_data = yield self.get_global_account_data_by_type_for_user(
"m.ignored_user_list", ignorer_user_id,
on_invalidate=cache_context.invalidate,
)
if not ignored_account_data:
defer.returnValue(False)
defer.returnValue(
ignored_user_id in ignored_account_data.get("ignored_users", {})
)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,10 +18,9 @@ import re
import simplejson as json import simplejson as json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.appservice import AppServiceTransaction from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices from synapse.config.appservice import load_appservices
from synapse.storage.roommember import RoomsForUser from synapse.storage.events import EventsWorkerStore
from ._base import SQLBaseStore from ._base import SQLBaseStore
@ -46,17 +46,16 @@ def _make_exclusive_regex(services_cache):
return exclusive_user_regex return exclusive_user_regex
class ApplicationServiceStore(SQLBaseStore): class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(ApplicationServiceStore, self).__init__(db_conn, hs)
self.hostname = hs.hostname
self.services_cache = load_appservices( self.services_cache = load_appservices(
hs.hostname, hs.hostname,
hs.config.app_service_config_files hs.config.app_service_config_files
) )
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
def get_app_services(self): def get_app_services(self):
return self.services_cache return self.services_cache
@ -99,83 +98,30 @@ class ApplicationServiceStore(SQLBaseStore):
return service return service
return None return None
def get_app_service_rooms(self, service): def get_app_service_by_id(self, as_id):
"""Get a list of RoomsForUser for this application service. """Get the application service with the given appservice ID.
Application services may be "interested" in lots of rooms depending on
the room ID, the room aliases, or the members in the room. This function
takes all of these into account and returns a list of RoomsForUser which
represent the entire list of room IDs that this application service
wants to know about.
Args: Args:
service: The application service to get a room list for. as_id (str): The application service ID.
Returns: Returns:
A list of RoomsForUser. synapse.appservice.ApplicationService or None.
""" """
return self.runInteraction( for service in self.services_cache:
"get_app_service_rooms", if service.id == as_id:
self._get_app_service_rooms_txn, return service
service, return None
)
def _get_app_service_rooms_txn(self, txn, service):
# get all rooms matching the room ID regex.
room_entries = self._simple_select_list_txn(
txn=txn, table="rooms", keyvalues=None, retcols=["room_id"]
)
matching_room_list = set([
r["room_id"] for r in room_entries if
service.is_interested_in_room(r["room_id"])
])
# resolve room IDs for matching room alias regex.
room_alias_mappings = self._simple_select_list_txn(
txn=txn, table="room_aliases", keyvalues=None,
retcols=["room_id", "room_alias"]
)
matching_room_list |= set([
r["room_id"] for r in room_alias_mappings if
service.is_interested_in_alias(r["room_alias"])
])
# get all rooms for every user for this AS. This is scoped to users on
# this HS only.
user_list = self._simple_select_list_txn(
txn=txn, table="users", keyvalues=None, retcols=["name"]
)
user_list = [
u["name"] for u in user_list if
service.is_interested_in_user(u["name"])
]
rooms_for_user_matching_user_id = set() # RoomsForUser list
for user_id in user_list:
# FIXME: This assumes this store is linked with RoomMemberStore :(
rooms_for_user = self._get_rooms_for_user_where_membership_is_txn(
txn=txn,
user_id=user_id,
membership_list=[Membership.JOIN]
)
rooms_for_user_matching_user_id |= set(rooms_for_user)
# make RoomsForUser tuples for room ids and aliases which are not in the
# main rooms_for_user_list - e.g. they are rooms which do not have AS
# registered users in it.
known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id]
missing_rooms_for_user = [
RoomsForUser(r, service.sender, "join") for r in
matching_room_list if r not in known_room_ids
]
rooms_for_user_matching_user_id |= set(missing_rooms_for_user)
return rooms_for_user_matching_user_id
class ApplicationServiceTransactionStore(SQLBaseStore): class ApplicationServiceStore(ApplicationServiceWorkerStore):
# This is currently empty due to there not being any AS storage functions
# that can't be run on the workers. Since this may change in future, and
# to keep consistency with the other stores, we keep this empty class for
# now.
pass
def __init__(self, db_conn, hs):
super(ApplicationServiceTransactionStore, self).__init__(db_conn, hs)
class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
EventsWorkerStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_appservices_by_state(self, state): def get_appservices_by_state(self, state):
"""Get a list of application services based on their state. """Get a list of application services based on their state.
@ -420,3 +366,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
events = yield self._get_events(event_ids) events = yield self._get_events(event_ids)
defer.returnValue((upper_bound, events)) defer.returnValue((upper_bound, events))
class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
# This is currently empty due to there not being any AS storage functions
# that can't be run on the workers. Since this may change in future, and
# to keep consistency with the other stores, we keep this empty class for
# now.
pass

View file

@ -242,6 +242,25 @@ class BackgroundUpdateStore(SQLBaseStore):
""" """
self._background_update_handlers[update_name] = update_handler self._background_update_handlers[update_name] = update_handler
def register_noop_background_update(self, update_name):
"""Register a noop handler for a background update.
This is useful when we previously did a background update, but no
longer wish to do the update. In this case the background update should
be removed from the schema delta files, but there may still be some
users who have the background update queued, so this method should
also be called to clear the update.
Args:
update_name (str): Name of update
"""
@defer.inlineCallbacks
def noop_update(progress, batch_size):
yield self._end_background_update(update_name)
defer.returnValue(1)
self.register_background_update_handler(update_name, noop_update)
def register_background_index_update(self, update_name, index_name, def register_background_index_update(self, update_name, index_name,
table, columns, where_clause=None, table, columns, where_clause=None,
unique=False, unique=False,

View file

@ -29,8 +29,7 @@ RoomAliasMapping = namedtuple(
) )
class DirectoryStore(SQLBaseStore): class DirectoryWorkerStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias): def get_association_from_room_alias(self, room_alias):
""" Get's the room_id and server list for a given room_alias """ Get's the room_id and server list for a given room_alias
@ -69,6 +68,28 @@ class DirectoryStore(SQLBaseStore):
RoomAliasMapping(room_id, room_alias.to_string(), servers) RoomAliasMapping(room_id, room_alias.to_string(), servers)
) )
def get_room_alias_creator(self, room_alias):
return self._simple_select_one_onecol(
table="room_aliases",
keyvalues={
"room_alias": room_alias,
},
retcol="creator",
desc="get_room_alias_creator",
allow_none=True
)
@cached(max_entries=5000)
def get_aliases_for_room(self, room_id):
return self._simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
desc="get_aliases_for_room",
)
class DirectoryStore(DirectoryWorkerStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def create_room_alias_association(self, room_alias, room_id, servers, creator=None): def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
""" Creates an associatin between a room alias and room_id/servers """ Creates an associatin between a room alias and room_id/servers
@ -116,17 +137,6 @@ class DirectoryStore(SQLBaseStore):
) )
defer.returnValue(ret) defer.returnValue(ret)
def get_room_alias_creator(self, room_alias):
return self._simple_select_one_onecol(
table="room_aliases",
keyvalues={
"room_alias": room_alias,
},
retcol="creator",
desc="get_room_alias_creator",
allow_none=True
)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_room_alias(self, room_alias): def delete_room_alias(self, room_alias):
room_id = yield self.runInteraction( room_id = yield self.runInteraction(
@ -135,7 +145,6 @@ class DirectoryStore(SQLBaseStore):
room_alias, room_alias,
) )
self.get_aliases_for_room.invalidate((room_id,))
defer.returnValue(room_id) defer.returnValue(room_id)
def _delete_room_alias_txn(self, txn, room_alias): def _delete_room_alias_txn(self, txn, room_alias):
@ -160,17 +169,12 @@ class DirectoryStore(SQLBaseStore):
(room_alias.to_string(),) (room_alias.to_string(),)
) )
return room_id self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (room_id,)
@cached(max_entries=5000)
def get_aliases_for_room(self, room_id):
return self._simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
desc="get_aliases_for_room",
) )
return room_id
def update_aliases_for_room(self, old_room_id, new_room_id, creator): def update_aliases_for_room(self, old_room_id, new_room_id, creator):
def _update_aliases_for_room_txn(txn): def _update_aliases_for_room_txn(txn):
sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?" sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"

View file

@ -62,3 +62,9 @@ class PostgresEngine(object):
def lock_table(self, txn, table): def lock_table(self, txn, table):
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,)) txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
def get_next_state_group_id(self, txn):
"""Returns an int that can be used as a new state_group ID
"""
txn.execute("SELECT nextval('state_group_id_seq')")
return txn.fetchone()[0]

View file

@ -16,6 +16,7 @@
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
import struct import struct
import threading
class Sqlite3Engine(object): class Sqlite3Engine(object):
@ -24,6 +25,11 @@ class Sqlite3Engine(object):
def __init__(self, database_module, database_config): def __init__(self, database_module, database_config):
self.module = database_module self.module = database_module
# The current max state_group, or None if we haven't looked
# in the DB yet.
self._current_state_group_id = None
self._current_state_group_id_lock = threading.Lock()
def check_database(self, txn): def check_database(self, txn):
pass pass
@ -43,6 +49,19 @@ class Sqlite3Engine(object):
def lock_table(self, txn, table): def lock_table(self, txn, table):
return return
def get_next_state_group_id(self, txn):
"""Returns an int that can be used as a new state_group ID
"""
# We do application locking here since if we're using sqlite then
# we are a single process synapse.
with self._current_state_group_id_lock:
if self._current_state_group_id is None:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
self._current_state_group_id = txn.fetchone()[0]
self._current_state_group_id += 1
return self._current_state_group_id
# Following functions taken from: https://github.com/coleifer/peewee # Following functions taken from: https://github.com/coleifer/peewee

View file

@ -15,7 +15,10 @@
from twisted.internet import defer from twisted.internet import defer
from ._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.events import EventsWorkerStore
from synapse.storage.signatures import SignatureWorkerStore
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from unpaddedbase64 import encode_base64 from unpaddedbase64 import encode_base64
@ -27,30 +30,8 @@ from Queue import PriorityQueue, Empty
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EventFederationStore(SQLBaseStore): class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
""" Responsible for storing and serving up the various graphs associated SQLBaseStore):
with an event. Including the main event graph and the auth chains for an
event.
Also has methods for getting the front (latest) and back (oldest) edges
of the event graphs. These are used to generate the parents for new events
and backfilling from another server respectively.
"""
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
def __init__(self, db_conn, hs):
super(EventFederationStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.EVENT_AUTH_STATE_ONLY,
self._background_delete_non_state_event_auth,
)
hs.get_clock().looping_call(
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
)
def get_auth_chain(self, event_ids, include_given=False): def get_auth_chain(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events. """Get auth events for given event_ids. The events *must* be state events.
@ -228,88 +209,6 @@ class EventFederationStore(SQLBaseStore):
return int(min_depth) if min_depth is not None else None return int(min_depth) if min_depth is not None else None
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
min_depth = self._get_min_depth_interaction(txn, room_id)
if min_depth and depth >= min_depth:
return
self._simple_upsert_txn(
txn,
table="room_depth",
keyvalues={
"room_id": room_id,
},
values={
"min_depth": depth,
},
)
def _handle_mult_prev_events(self, txn, events):
"""
For the given event, update the event edges table and forward and
backward extremities tables.
"""
self._simple_insert_many_txn(
txn,
table="event_edges",
values=[
{
"event_id": ev.event_id,
"prev_event_id": e_id,
"room_id": ev.room_id,
"is_state": False,
}
for ev in events
for e_id, _ in ev.prev_events
],
)
self._update_backward_extremeties(txn, events)
def _update_backward_extremeties(self, txn, events):
"""Updates the event_backward_extremities tables based on the new/updated
events being persisted.
This is called for new events *and* for events that were outliers, but
are now being persisted as non-outliers.
Forward extremities are handled when we first start persisting the events.
"""
events_by_room = {}
for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev)
query = (
"INSERT INTO event_backward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS ("
" SELECT 1 FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
" )"
" AND NOT EXISTS ("
" SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
" AND outlier = ?"
" )"
)
txn.executemany(query, [
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
for ev in events for e_id, _ in ev.prev_events
if not ev.internal_metadata.is_outlier()
])
query = (
"DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
)
txn.executemany(
query,
[
(ev.event_id, ev.room_id) for ev in events
if not ev.internal_metadata.is_outlier()
]
)
def get_forward_extremeties_for_room(self, room_id, stream_ordering): def get_forward_extremeties_for_room(self, room_id, stream_ordering):
"""For a given room_id and stream_ordering, return the forward """For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time". extremeties of the room at that point in "time".
@ -371,28 +270,6 @@ class EventFederationStore(SQLBaseStore):
get_forward_extremeties_for_room_txn get_forward_extremeties_for_room_txn
) )
def _delete_old_forward_extrem_cache(self):
def _delete_old_forward_extrem_cache_txn(txn):
# Delete entries older than a month, while making sure we don't delete
# the only entries for a room.
sql = ("""
DELETE FROM stream_ordering_to_exterm
WHERE
room_id IN (
SELECT room_id
FROM stream_ordering_to_exterm
WHERE stream_ordering > ?
) AND stream_ordering < ?
""")
txn.execute(
sql,
(self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
)
return self.runInteraction(
"_delete_old_forward_extrem_cache",
_delete_old_forward_extrem_cache_txn
)
def get_backfill_events(self, room_id, event_list, limit): def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and """Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit` including) the events in event_list. Return a list of max size `limit`
@ -522,6 +399,135 @@ class EventFederationStore(SQLBaseStore):
return event_results return event_results
class EventFederationStore(EventFederationWorkerStore):
""" Responsible for storing and serving up the various graphs associated
with an event. Including the main event graph and the auth chains for an
event.
Also has methods for getting the front (latest) and back (oldest) edges
of the event graphs. These are used to generate the parents for new events
and backfilling from another server respectively.
"""
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
def __init__(self, db_conn, hs):
super(EventFederationStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.EVENT_AUTH_STATE_ONLY,
self._background_delete_non_state_event_auth,
)
hs.get_clock().looping_call(
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
)
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
min_depth = self._get_min_depth_interaction(txn, room_id)
if min_depth and depth >= min_depth:
return
self._simple_upsert_txn(
txn,
table="room_depth",
keyvalues={
"room_id": room_id,
},
values={
"min_depth": depth,
},
)
def _handle_mult_prev_events(self, txn, events):
"""
For the given event, update the event edges table and forward and
backward extremities tables.
"""
self._simple_insert_many_txn(
txn,
table="event_edges",
values=[
{
"event_id": ev.event_id,
"prev_event_id": e_id,
"room_id": ev.room_id,
"is_state": False,
}
for ev in events
for e_id, _ in ev.prev_events
],
)
self._update_backward_extremeties(txn, events)
def _update_backward_extremeties(self, txn, events):
"""Updates the event_backward_extremities tables based on the new/updated
events being persisted.
This is called for new events *and* for events that were outliers, but
are now being persisted as non-outliers.
Forward extremities are handled when we first start persisting the events.
"""
events_by_room = {}
for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev)
query = (
"INSERT INTO event_backward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS ("
" SELECT 1 FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
" )"
" AND NOT EXISTS ("
" SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
" AND outlier = ?"
" )"
)
txn.executemany(query, [
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
for ev in events for e_id, _ in ev.prev_events
if not ev.internal_metadata.is_outlier()
])
query = (
"DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
)
txn.executemany(
query,
[
(ev.event_id, ev.room_id) for ev in events
if not ev.internal_metadata.is_outlier()
]
)
def _delete_old_forward_extrem_cache(self):
def _delete_old_forward_extrem_cache_txn(txn):
# Delete entries older than a month, while making sure we don't delete
# the only entries for a room.
sql = ("""
DELETE FROM stream_ordering_to_exterm
WHERE
room_id IN (
SELECT room_id
FROM stream_ordering_to_exterm
WHERE stream_ordering > ?
) AND stream_ordering < ?
""")
txn.execute(
sql,
(self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
)
return self.runInteraction(
"_delete_old_forward_extrem_cache",
_delete_old_forward_extrem_cache_txn
)
def clean_room_for_join(self, room_id): def clean_room_for_join(self, room_id):
return self.runInteraction( return self.runInteraction(
"clean_room_for_join", "clean_room_for_join",

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from synapse.storage._base import SQLBaseStore, LoggingTransaction
from twisted.internet import defer from twisted.internet import defer
from synapse.util.async import sleep from synapse.util.async import sleep
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
@ -62,60 +63,28 @@ def _deserialize_action(actions, is_highlight):
return DEFAULT_NOTIF_ACTION return DEFAULT_NOTIF_ACTION
class EventPushActionsStore(SQLBaseStore): class EventPushActionsWorkerStore(SQLBaseStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(EventPushActionsStore, self).__init__(db_conn, hs) super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
self.register_background_index_update( # These get correctly set by _find_stream_orderings_for_times_txn
self.EPA_HIGHLIGHT_INDEX, self.stream_ordering_month_ago = None
index_name="event_push_actions_u_highlight", self.stream_ordering_day_ago = None
table="event_push_actions",
columns=["user_id", "stream_ordering"], cur = LoggingTransaction(
db_conn.cursor(),
name="_find_stream_orderings_for_times_txn",
database_engine=self.database_engine,
after_callbacks=[],
exception_callbacks=[],
) )
self._find_stream_orderings_for_times_txn(cur)
cur.close()
self.register_background_index_update( self.find_stream_orderings_looping_call = self._clock.looping_call(
"event_push_actions_highlights_index", self._find_stream_orderings_for_times, 10 * 60 * 1000
index_name="event_push_actions_highlights_index",
table="event_push_actions",
columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
where_clause="highlight=1"
) )
self._doing_notif_rotation = False
self._rotate_notif_loop = self._clock.looping_call(
self._rotate_notifs, 30 * 60 * 1000
)
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
"""
Args:
event: the event set actions for
tuples: list of tuples of (user_id, actions)
"""
values = []
for uid, actions in tuples:
is_highlight = 1 if _action_has_highlight(actions) else 0
values.append({
'room_id': event.room_id,
'event_id': event.event_id,
'user_id': uid,
'actions': _serialize_action(actions, is_highlight),
'stream_ordering': event.internal_metadata.stream_ordering,
'topological_ordering': event.depth,
'notif': 1,
'highlight': is_highlight,
})
for uid, __ in tuples:
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(event.room_id, uid)
)
self._simple_insert_many_txn(txn, "event_push_actions", values)
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user( def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id self, room_id, user_id, last_read_event_id
@ -432,6 +401,280 @@ class EventPushActionsStore(SQLBaseStore):
# Now return the first `limit` # Now return the first `limit`
defer.returnValue(notifs[:limit]) defer.returnValue(notifs[:limit])
def add_push_actions_to_staging(self, event_id, user_id_actions):
"""Add the push actions for the event to the push action staging area.
Args:
event_id (str)
user_id_actions (dict[str, list[dict|str])]): A dictionary mapping
user_id to list of push actions, where an action can either be
a string or dict.
Returns:
Deferred
"""
if not user_id_actions:
return
# This is a helper function for generating the necessary tuple that
# can be used to inert into the `event_push_actions_staging` table.
def _gen_entry(user_id, actions):
is_highlight = 1 if _action_has_highlight(actions) else 0
return (
event_id, # event_id column
user_id, # user_id column
_serialize_action(actions, is_highlight), # actions column
1, # notif column
is_highlight, # highlight column
)
def _add_push_actions_to_staging_txn(txn):
# We don't use _simple_insert_many here to avoid the overhead
# of generating lists of dicts.
sql = """
INSERT INTO event_push_actions_staging
(event_id, user_id, actions, notif, highlight)
VALUES (?, ?, ?, ?, ?)
"""
txn.executemany(sql, (
_gen_entry(user_id, actions)
for user_id, actions in user_id_actions.iteritems()
))
return self.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
)
def remove_push_actions_from_staging(self, event_id):
"""Called if we failed to persist the event to ensure that stale push
actions don't build up in the DB
Args:
event_id (str)
"""
return self._simple_delete(
table="event_push_actions_staging",
keyvalues={
"event_id": event_id,
},
desc="remove_push_actions_from_staging",
)
@defer.inlineCallbacks
def _find_stream_orderings_for_times(self):
yield self.runInteraction(
"_find_stream_orderings_for_times",
self._find_stream_orderings_for_times_txn
)
def _find_stream_orderings_for_times_txn(self, txn):
logger.info("Searching for stream ordering 1 month ago")
self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
)
logger.info(
"Found stream ordering 1 month ago: it's %d",
self.stream_ordering_month_ago
)
logger.info("Searching for stream ordering 1 day ago")
self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
)
logger.info(
"Found stream ordering 1 day ago: it's %d",
self.stream_ordering_day_ago
)
def find_first_stream_ordering_after_ts(self, ts):
"""Gets the stream ordering corresponding to a given timestamp.
Specifically, finds the stream_ordering of the first event that was
received on or after the timestamp. This is done by a binary search on
the events table, since there is no index on received_ts, so is
relatively slow.
Args:
ts (int): timestamp in millis
Returns:
Deferred[int]: stream ordering of the first event received on/after
the timestamp
"""
return self.runInteraction(
"_find_first_stream_ordering_after_ts_txn",
self._find_first_stream_ordering_after_ts_txn,
ts,
)
@staticmethod
def _find_first_stream_ordering_after_ts_txn(txn, ts):
"""
Find the stream_ordering of the first event that was received on or
after a given timestamp. This is relatively slow as there is no index
on received_ts but we can then use this to delete push actions before
this.
received_ts must necessarily be in the same order as stream_ordering
and stream_ordering is indexed, so we manually binary search using
stream_ordering
Args:
txn (twisted.enterprise.adbapi.Transaction):
ts (int): timestamp to search for
Returns:
int: stream ordering
"""
txn.execute("SELECT MAX(stream_ordering) FROM events")
max_stream_ordering = txn.fetchone()[0]
if max_stream_ordering is None:
return 0
# We want the first stream_ordering in which received_ts is greater
# than or equal to ts. Call this point X.
#
# We maintain the invariants:
#
# range_start <= X <= range_end
#
range_start = 0
range_end = max_stream_ordering + 1
# Given a stream_ordering, look up the timestamp at that
# stream_ordering.
#
# The array may be sparse (we may be missing some stream_orderings).
# We treat the gaps as the same as having the same value as the
# preceding entry, because we will pick the lowest stream_ordering
# which satisfies our requirement of received_ts >= ts.
#
# For example, if our array of events indexed by stream_ordering is
# [10, <none>, 20], we should treat this as being equivalent to
# [10, 10, 20].
#
sql = (
"SELECT received_ts FROM events"
" WHERE stream_ordering <= ?"
" ORDER BY stream_ordering DESC"
" LIMIT 1"
)
while range_end - range_start > 0:
middle = (range_end + range_start) // 2
txn.execute(sql, (middle,))
row = txn.fetchone()
if row is None:
# no rows with stream_ordering<=middle
range_start = middle + 1
continue
middle_ts = row[0]
if ts > middle_ts:
# we got a timestamp lower than the one we were looking for.
# definitely need to look higher: X > middle.
range_start = middle + 1
else:
# we got a timestamp higher than (or the same as) the one we
# were looking for. We aren't yet sure about the point we
# looked up, but we can be sure that X <= middle.
range_end = middle
return range_end
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
def __init__(self, db_conn, hs):
super(EventPushActionsStore, self).__init__(db_conn, hs)
self.register_background_index_update(
self.EPA_HIGHLIGHT_INDEX,
index_name="event_push_actions_u_highlight",
table="event_push_actions",
columns=["user_id", "stream_ordering"],
)
self.register_background_index_update(
"event_push_actions_highlights_index",
index_name="event_push_actions_highlights_index",
table="event_push_actions",
columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
where_clause="highlight=1"
)
self._doing_notif_rotation = False
self._rotate_notif_loop = self._clock.looping_call(
self._rotate_notifs, 30 * 60 * 1000
)
def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts,
all_events_and_contexts):
"""Handles moving push actions from staging table to main
event_push_actions table for all events in `events_and_contexts`.
Also ensures that all events in `all_events_and_contexts` are removed
from the push action staging area.
Args:
events_and_contexts (list[(EventBase, EventContext)]): events
we are persisting
all_events_and_contexts (list[(EventBase, EventContext)]): all
events that we were going to persist. This includes events
we've already persisted, etc, that wouldn't appear in
events_and_context.
"""
sql = """
INSERT INTO event_push_actions (
room_id, event_id, user_id, actions, stream_ordering,
topological_ordering, notif, highlight
)
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
FROM event_push_actions_staging
WHERE event_id = ?
"""
if events_and_contexts:
txn.executemany(sql, (
(
event.room_id, event.internal_metadata.stream_ordering,
event.depth, event.event_id,
)
for event, _ in events_and_contexts
))
for event, _ in events_and_contexts:
user_ids = self._simple_select_onecol_txn(
txn,
table="event_push_actions_staging",
keyvalues={
"event_id": event.event_id,
},
retcol="user_id",
)
for uid in user_ids:
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(event.room_id, uid,)
)
# Now we delete the staging area for *all* events that were being
# persisted.
txn.executemany(
"DELETE FROM event_push_actions_staging WHERE event_id = ?",
(
(event.event_id,)
for event, _ in all_events_and_contexts
)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_push_actions_for_user(self, user_id, before=None, limit=50, def get_push_actions_for_user(self, user_id, before=None, limit=50,
only_highlight=False): only_highlight=False):
@ -550,69 +793,6 @@ class EventPushActionsStore(SQLBaseStore):
WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
""", (room_id, user_id, stream_ordering)) """, (room_id, user_id, stream_ordering))
@defer.inlineCallbacks
def _find_stream_orderings_for_times(self):
yield self.runInteraction(
"_find_stream_orderings_for_times",
self._find_stream_orderings_for_times_txn
)
def _find_stream_orderings_for_times_txn(self, txn):
logger.info("Searching for stream ordering 1 month ago")
self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
)
logger.info(
"Found stream ordering 1 month ago: it's %d",
self.stream_ordering_month_ago
)
logger.info("Searching for stream ordering 1 day ago")
self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
)
logger.info(
"Found stream ordering 1 day ago: it's %d",
self.stream_ordering_day_ago
)
def _find_first_stream_ordering_after_ts_txn(self, txn, ts):
"""
Find the stream_ordering of the first event that was received after
a given timestamp. This is relatively slow as there is no index on
received_ts but we can then use this to delete push actions before
this.
received_ts must necessarily be in the same order as stream_ordering
and stream_ordering is indexed, so we manually binary search using
stream_ordering
"""
txn.execute("SELECT MAX(stream_ordering) FROM events")
max_stream_ordering = txn.fetchone()[0]
if max_stream_ordering is None:
return 0
range_start = 0
range_end = max_stream_ordering
sql = (
"SELECT received_ts FROM events"
" WHERE stream_ordering > ?"
" ORDER BY stream_ordering"
" LIMIT 1"
)
while range_end - range_start > 1:
middle = int((range_end + range_start) / 2)
txn.execute(sql, (middle,))
middle_ts = txn.fetchone()[0]
if ts > middle_ts:
range_start = middle
else:
range_end = middle
return range_end
@defer.inlineCallbacks @defer.inlineCallbacks
def _rotate_notifs(self): def _rotate_notifs(self):
if self._doing_notif_rotation or self.stream_ordering_day_ago is None: if self._doing_notif_rotation or self.stream_ordering_day_ago is None:

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,395 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from twisted.internet import defer, reactor
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.util.logcontext import (
preserve_fn, PreserveLoggingContext, make_deferred_yieldable
)
from synapse.util.metrics import Measure
from synapse.api.errors import SynapseError
from collections import namedtuple
import logging
import ujson as json
# these are only included to make the type annotations work
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
logger = logging.getLogger(__name__)
# These values are used in the `enqueus_event` and `_do_fetch` methods to
# control how we batch/bulk fetch events from the database.
# The values are plucked out of thing air to make initial sync run faster
# on jki.re
# TODO: Make these configurable.
EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False,
allow_none=False):
"""Get an event from the database by event_id.
Args:
event_id (str): The event_id of the event to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
allow_none (bool): If True, return None if no event found, if
False throw an exception.
Returns:
Deferred : A FrozenEvent.
"""
events = yield self._get_events(
[event_id],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
if not events and not allow_none:
raise SynapseError(404, "Could not find event %s" % (event_id,))
defer.returnValue(events[0] if events else None)
@defer.inlineCallbacks
def get_events(self, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False):
"""Get events from the database
Args:
event_ids (list): The event_ids of the events to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
Returns:
Deferred : Dict from event_id to event.
"""
events = yield self._get_events(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
defer.returnValue({e.event_id: e for e in events})
@defer.inlineCallbacks
def _get_events(self, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False):
if not event_ids:
defer.returnValue([])
event_id_list = event_ids
event_ids = set(event_ids)
event_entry_map = self._get_events_from_cache(
event_ids,
allow_rejected=allow_rejected,
)
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
if missing_events_ids:
missing_events = yield self._enqueue_events(
missing_events_ids,
check_redacted=check_redacted,
allow_rejected=allow_rejected,
)
event_entry_map.update(missing_events)
events = []
for event_id in event_id_list:
entry = event_entry_map.get(event_id, None)
if not entry:
continue
if allow_rejected or not entry.event.rejected_reason:
if check_redacted and entry.redacted_event:
event = entry.redacted_event
else:
event = entry.event
events.append(event)
if get_prev_content:
if "replaces_state" in event.unsigned:
prev = yield self.get_event(
event.unsigned["replaces_state"],
get_prev_content=False,
allow_none=True,
)
if prev:
event.unsigned = dict(event.unsigned)
event.unsigned["prev_content"] = prev.content
event.unsigned["prev_sender"] = prev.sender
defer.returnValue(events)
def _invalidate_get_event_cache(self, event_id):
self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
"""Fetch events from the caches
Args:
events (list(str)): list of event_ids to fetch
allow_rejected (bool): Whether to teturn events that were rejected
update_metrics (bool): Whether to update the cache hit ratio metrics
Returns:
dict of event_id -> _EventCacheEntry for each event_id in cache. If
allow_rejected is `False` then there will still be an entry but it
will be `None`
"""
event_map = {}
for event_id in events:
ret = self._get_event_cache.get(
(event_id,), None,
update_metrics=update_metrics,
)
if not ret:
continue
if allow_rejected or not ret.event.rejected_reason:
event_map[event_id] = ret
else:
event_map[event_id] = None
return event_map
def _do_fetch(self, conn):
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
event_list = []
i = 0
while True:
try:
with self._event_fetch_lock:
event_list = self._event_fetch_list
self._event_fetch_list = []
if not event_list:
single_threaded = self.database_engine.single_threaded
if single_threaded or i > EVENT_QUEUE_ITERATIONS:
self._event_fetch_ongoing -= 1
return
else:
self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
i += 1
continue
i = 0
event_id_lists = zip(*event_list)[0]
event_ids = [
item for sublist in event_id_lists for item in sublist
]
rows = self._new_transaction(
conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
)
row_dict = {
r["event_id"]: r
for r in rows
}
# We only want to resolve deferreds from the main thread
def fire(lst, res):
for ids, d in lst:
if not d.called:
try:
with PreserveLoggingContext():
d.callback([
res[i]
for i in ids
if i in res
])
except Exception:
logger.exception("Failed to callback")
with PreserveLoggingContext():
reactor.callFromThread(fire, event_list, row_dict)
except Exception as e:
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
def fire(evs):
for _, d in evs:
if not d.called:
with PreserveLoggingContext():
d.errback(e)
if event_list:
with PreserveLoggingContext():
reactor.callFromThread(fire, event_list)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
"""
if not events:
defer.returnValue({})
events_d = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append(
(events, events_d)
)
self._event_fetch_lock.notify()
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
self._event_fetch_ongoing += 1
should_start = True
else:
should_start = False
if should_start:
with PreserveLoggingContext():
self.runWithConnection(
self._do_fetch
)
logger.debug("Loading %d events", len(events))
with PreserveLoggingContext():
rows = yield events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
res = yield make_deferred_yieldable(defer.gatherResults(
[
preserve_fn(self._get_event_from_row)(
row["internal_metadata"], row["json"], row["redacts"],
rejected_reason=row["rejects"],
)
for row in rows
],
consumeErrors=True
))
defer.returnValue({
e.event.event_id: e
for e in res if e
})
def _fetch_event_rows(self, txn, events):
rows = []
N = 200
for i in range(1 + len(events) / N):
evs = events[i * N:(i + 1) * N]
if not evs:
break
sql = (
"SELECT "
" e.event_id as event_id, "
" e.internal_metadata,"
" e.json,"
" r.redacts as redacts,"
" rej.event_id as rejects "
" FROM event_json as e"
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)"
) % (",".join(["?"] * len(evs)),)
txn.execute(sql, evs)
rows.extend(self.cursor_to_dict(txn))
return rows
@defer.inlineCallbacks
def _get_event_from_row(self, internal_metadata, js, redacted,
rejected_reason=None):
with Measure(self._clock, "_get_event_from_row"):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
if rejected_reason:
rejected_reason = yield self._simple_select_one_onecol(
table="rejections",
keyvalues={"event_id": rejected_reason},
retcol="reason",
desc="_get_event_from_row_rejected_reason",
)
original_ev = FrozenEvent(
d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
redacted_event = None
if redacted:
redacted_event = prune_event(original_ev)
redaction_id = yield self._simple_select_one_onecol(
table="redactions",
keyvalues={"redacts": redacted_event.event_id},
retcol="event_id",
desc="_get_event_from_row_redactions",
)
redacted_event.unsigned["redacted_by"] = redaction_id
# Get the redaction event.
because = yield self.get_event(
redaction_id,
check_redacted=False,
allow_none=True,
)
if because:
# It's fine to do add the event directly, since get_pdu_json
# will serialise this field correctly
redacted_event.unsigned["redacted_because"] = because
cache_entry = _EventCacheEntry(
event=original_ev,
redacted_event=redacted_event,
)
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
defer.returnValue(cache_entry)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,11 +15,17 @@
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.storage.appservice import ApplicationServiceWorkerStore
from synapse.storage.pusher import PusherWorkerStore
from synapse.storage.receipts import ReceiptsWorkerStore
from synapse.storage.roommember import RoomMemberWorkerStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.push.baserules import list_with_base_rules from synapse.push.baserules import list_with_base_rules
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from twisted.internet import defer from twisted.internet import defer
import abc
import logging import logging
import simplejson as json import simplejson as json
@ -48,7 +55,43 @@ def _load_rules(rawrules, enabled_map):
return rules return rules
class PushRuleStore(SQLBaseStore): class PushRulesWorkerStore(ApplicationServiceWorkerStore,
ReceiptsWorkerStore,
PusherWorkerStore,
RoomMemberWorkerStore,
SQLBaseStore):
"""This is an abstract base class where subclasses must implement
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
# This ABCMeta metaclass ensures that we cannot be instantiated without
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
def __init__(self, db_conn, hs):
super(PushRulesWorkerStore, self).__init__(db_conn, hs)
push_rules_prefill, push_rules_id = self._get_cache_dict(
db_conn, "push_rules_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self.get_max_push_rules_stream_id(),
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache", push_rules_id,
prefilled_cache=push_rules_prefill,
)
@abc.abstractmethod
def get_max_push_rules_stream_id(self):
"""Get the position of the push rules stream.
Returns:
int
"""
raise NotImplementedError()
@cachedInlineCallbacks(max_entries=5000) @cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id): def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
@ -89,6 +132,22 @@ class PushRuleStore(SQLBaseStore):
r['rule_id']: False if r['enabled'] == 0 else True for r in results r['rule_id']: False if r['enabled'] == 0 else True for r in results
}) })
def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
return defer.succeed(False)
else:
def have_push_rules_changed_txn(txn):
sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream"
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
count, = txn.fetchone()
return bool(count)
return self.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
@cachedList(cached_method_name="get_push_rules_for_user", @cachedList(cached_method_name="get_push_rules_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True) list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules(self, user_ids): def bulk_get_push_rules(self, user_ids):
@ -228,6 +287,8 @@ class PushRuleStore(SQLBaseStore):
results.setdefault(row['user_name'], {})[row['rule_id']] = enabled results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
defer.returnValue(results) defer.returnValue(results)
class PushRuleStore(PushRulesWorkerStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def add_push_rule( def add_push_rule(
self, user_id, rule_id, priority_class, conditions, actions, self, user_id, rule_id, priority_class, conditions, actions,
@ -526,21 +587,8 @@ class PushRuleStore(SQLBaseStore):
room stream ordering it corresponds to.""" room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_current_token() return self._push_rules_stream_id_gen.get_current_token()
def have_push_rules_changed_for_user(self, user_id, last_id): def get_max_push_rules_stream_id(self):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): return self.get_push_rules_stream_token()[0]
return defer.succeed(False)
else:
def have_push_rules_changed_txn(txn):
sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream"
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
count, = txn.fetchone()
return bool(count)
return self.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
class RuleNotFoundException(Exception): class RuleNotFoundException(Exception):

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -27,7 +28,7 @@ import types
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PusherStore(SQLBaseStore): class PusherWorkerStore(SQLBaseStore):
def _decode_pushers_rows(self, rows): def _decode_pushers_rows(self, rows):
for r in rows: for r in rows:
dataJson = r['data'] dataJson = r['data']
@ -102,9 +103,6 @@ class PusherStore(SQLBaseStore):
rows = yield self.runInteraction("get_all_pushers", get_pushers) rows = yield self.runInteraction("get_all_pushers", get_pushers)
defer.returnValue(rows) defer.returnValue(rows)
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
def get_all_updated_pushers(self, last_id, current_id, limit): def get_all_updated_pushers(self, last_id, current_id, limit):
if last_id == current_id: if last_id == current_id:
return defer.succeed(([], [])) return defer.succeed(([], []))
@ -198,6 +196,11 @@ class PusherStore(SQLBaseStore):
defer.returnValue(result) defer.returnValue(result)
class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
@defer.inlineCallbacks @defer.inlineCallbacks
def add_pusher(self, user_id, access_token, kind, app_id, def add_pusher(self, user_id, access_token, kind, app_id,
app_display_name, device_display_name, app_display_name, device_display_name,
@ -230,14 +233,18 @@ class PusherStore(SQLBaseStore):
) )
if newly_inserted: if newly_inserted:
# get_if_user_has_pusher only cares if the user has self.runInteraction(
# at least *one* pusher. "add_pusher",
self.get_if_user_has_pusher.invalidate(user_id,) self._invalidate_cache_and_stream,
self.get_if_user_has_pusher, (user_id,)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
def delete_pusher_txn(txn, stream_id): def delete_pusher_txn(txn, stream_id):
txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,)) self._invalidate_cache_and_stream(
txn, self.get_if_user_has_pusher, (user_id,)
)
self._simple_delete_one_txn( self._simple_delete_one_txn(
txn, txn,

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,11 +15,13 @@
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from ._base import SQLBaseStore
from .util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer from twisted.internet import defer
import abc
import logging import logging
import ujson as json import ujson as json
@ -26,39 +29,36 @@ import ujson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReceiptsStore(SQLBaseStore): class ReceiptsWorkerStore(SQLBaseStore):
"""This is an abstract base class where subclasses must implement
`get_max_receipt_stream_id` which can be called in the initializer.
"""
# This ABCMeta metaclass ensures that we cannot be instantiated without
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(ReceiptsStore, self).__init__(db_conn, hs) super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
self._receipts_stream_cache = StreamChangeCache( self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
) )
@abc.abstractmethod
def get_max_receipt_stream_id(self):
"""Get the current max stream ID for receipts stream
Returns:
int
"""
raise NotImplementedError()
@cachedInlineCallbacks() @cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id): def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read") receipts = yield self.get_receipts_for_room(room_id, "m.read")
defer.returnValue(set(r['user_id'] for r in receipts)) defer.returnValue(set(r['user_id'] for r in receipts))
def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
user_id):
if receipt_type != "m.read":
return
# Returns an ObservableDeferred
res = self.get_users_with_read_receipts_in_room.cache.get(
room_id, None, update_metrics=False,
)
if res:
if isinstance(res, defer.Deferred) and res.called:
res = res.result
if user_id in res:
# We'd only be adding to the set, so no point invalidating if the
# user is already there
return
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
@cached(num_args=2) @cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type): def get_receipts_for_room(self, room_id, receipt_type):
return self._simple_select_list( return self._simple_select_list(
@ -270,6 +270,59 @@ class ReceiptsStore(SQLBaseStore):
} }
defer.returnValue(results) defer.returnValue(results)
def get_all_updated_receipts(self, last_id, current_id, limit=None):
if last_id == current_id:
return defer.succeed([])
def get_all_updated_receipts_txn(txn):
sql = (
"SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
" FROM receipts_linearized"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
)
args = [last_id, current_id]
if limit is not None:
sql += " LIMIT ?"
args.append(limit)
txn.execute(sql, args)
return txn.fetchall()
return self.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
)
def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
user_id):
if receipt_type != "m.read":
return
# Returns an ObservableDeferred
res = self.get_users_with_read_receipts_in_room.cache.get(
room_id, None, update_metrics=False,
)
if res:
if isinstance(res, defer.Deferred) and res.called:
res = res.result
if user_id in res:
# We'd only be adding to the set, so no point invalidating if the
# user is already there
return
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
class ReceiptsStore(ReceiptsWorkerStore):
def __init__(self, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
super(ReceiptsStore, self).__init__(db_conn, hs)
def get_max_receipt_stream_id(self): def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token() return self._receipts_id_gen.get_current_token()
@ -457,25 +510,3 @@ class ReceiptsStore(SQLBaseStore):
"data": json.dumps(data), "data": json.dumps(data),
} }
) )
def get_all_updated_receipts(self, last_id, current_id, limit=None):
if last_id == current_id:
return defer.succeed([])
def get_all_updated_receipts_txn(txn):
sql = (
"SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
" FROM receipts_linearized"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
)
args = [last_id, current_id]
if limit is not None:
sql += " LIMIT ?"
args.append(limit)
txn.execute(sql, args)
return txn.fetchall()
return self.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
)

View file

@ -19,10 +19,70 @@ from twisted.internet import defer
from synapse.api.errors import StoreError, Codes from synapse.api.errors import StoreError, Codes
from synapse.storage import background_updates from synapse.storage import background_updates
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
class RegistrationStore(background_updates.BackgroundUpdateStore): class RegistrationWorkerStore(SQLBaseStore):
@cached()
def get_user_by_id(self, user_id):
return self._simple_select_one(
table="users",
keyvalues={
"name": user_id,
},
retcols=["name", "password_hash", "is_guest"],
allow_none=True,
desc="get_user_by_id",
)
@cached()
def get_user_by_access_token(self, token):
"""Get a user from the given access token.
Args:
token (str): The access token of a user.
Returns:
defer.Deferred: None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",
self._query_for_auth,
token
)
@defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
allow_none=True,
desc="is_server_admin",
)
defer.returnValue(res if res else False)
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
" access_tokens.device_id"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
)
txn.execute(sql, (token,))
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]
return None
class RegistrationStore(RegistrationWorkerStore,
background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs) super(RegistrationStore, self).__init__(db_conn, hs)
@ -39,12 +99,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
# we no longer use refresh tokens, but it's possible that some people # we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just # might have a background update queued to build this index. Just
# clear the background update. # clear the background update.
@defer.inlineCallbacks self.register_noop_background_update("refresh_tokens_device_index")
def noop_update(progress, batch_size):
yield self._end_background_update("refresh_tokens_device_index")
defer.returnValue(1)
self.register_background_update_handler(
"refresh_tokens_device_index", noop_update)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id=None): def add_access_token_to_user(self, user_id, token, device_id=None):
@ -192,18 +247,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
) )
txn.call_after(self.is_guest.invalidate, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,))
@cached()
def get_user_by_id(self, user_id):
return self._simple_select_one(
table="users",
keyvalues={
"name": user_id,
},
retcols=["name", "password_hash", "is_guest"],
allow_none=True,
desc="get_user_by_id",
)
def get_users_by_id_case_insensitive(self, user_id): def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively. """Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash. Returns a mapping of user_id -> password_hash.
@ -309,34 +352,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
return self.runInteraction("delete_access_token", f) return self.runInteraction("delete_access_token", f)
@cached()
def get_user_by_access_token(self, token):
"""Get a user from the given access token.
Args:
token (str): The access token of a user.
Returns:
defer.Deferred: None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",
self._query_for_auth,
token
)
@defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
allow_none=True,
desc="is_server_admin",
)
defer.returnValue(res if res else False)
@cachedInlineCallbacks() @cachedInlineCallbacks()
def is_guest(self, user_id): def is_guest(self, user_id):
res = yield self._simple_select_one_onecol( res = yield self._simple_select_one_onecol(
@ -349,22 +364,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
defer.returnValue(res if res else False) defer.returnValue(res if res else False)
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
" access_tokens.device_id"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
)
txn.execute(sql, (token,))
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]
return None
@defer.inlineCallbacks @defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at): def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self._simple_upsert("user_threepids", { yield self._simple_upsert("user_threepids", {

View file

@ -16,11 +16,10 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.search import SearchStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from ._base import SQLBaseStore
from .engines import PostgresEngine, Sqlite3Engine
import collections import collections
import logging import logging
import ujson as json import ujson as json
@ -40,7 +39,126 @@ RatelimitOverride = collections.namedtuple(
) )
class RoomStore(SQLBaseStore): class RoomWorkerStore(SQLBaseStore):
def get_public_room_ids(self):
return self._simple_select_onecol(
table="rooms",
keyvalues={
"is_public": True,
},
retcol="room_id",
desc="get_public_room_ids",
)
@cached(num_args=2, max_entries=100)
def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
"""Get pulbic rooms for a particular list, or across all lists.
Args:
stream_id (int)
network_tuple (ThirdPartyInstanceID): The list to use (None, None)
means the main list, None means all lsits.
"""
return self.runInteraction(
"get_public_room_ids_at_stream_id",
self.get_public_room_ids_at_stream_id_txn,
stream_id, network_tuple=network_tuple
)
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
network_tuple):
return {
rm
for rm, vis in self.get_published_at_stream_id_txn(
txn, stream_id, network_tuple=network_tuple
).items()
if vis
}
def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
if network_tuple:
# We want to get from a particular list. No aggregation required.
sql = ("""
SELECT room_id, visibility FROM public_room_list_stream
INNER JOIN (
SELECT room_id, max(stream_id) AS stream_id
FROM public_room_list_stream
WHERE stream_id <= ? %s
GROUP BY room_id
) grouped USING (room_id, stream_id)
""")
if network_tuple.appservice_id is not None:
txn.execute(
sql % ("AND appservice_id = ? AND network_id = ?",),
(stream_id, network_tuple.appservice_id, network_tuple.network_id,)
)
else:
txn.execute(
sql % ("AND appservice_id IS NULL",),
(stream_id,)
)
return dict(txn)
else:
# We want to get from all lists, so we need to aggregate the results
logger.info("Executing full list")
sql = ("""
SELECT room_id, visibility
FROM public_room_list_stream
INNER JOIN (
SELECT
room_id, max(stream_id) AS stream_id, appservice_id,
network_id
FROM public_room_list_stream
WHERE stream_id <= ?
GROUP BY room_id, appservice_id, network_id
) grouped USING (room_id, stream_id)
""")
txn.execute(
sql,
(stream_id,)
)
results = {}
# A room is visible if its visible on any list.
for room_id, visibility in txn:
results[room_id] = bool(visibility) or results.get(room_id, False)
return results
def get_public_room_changes(self, prev_stream_id, new_stream_id,
network_tuple):
def get_public_room_changes_txn(txn):
then_rooms = self.get_public_room_ids_at_stream_id_txn(
txn, prev_stream_id, network_tuple
)
now_rooms_dict = self.get_published_at_stream_id_txn(
txn, new_stream_id, network_tuple
)
now_rooms_visible = set(
rm for rm, vis in now_rooms_dict.items() if vis
)
now_rooms_not_visible = set(
rm for rm, vis in now_rooms_dict.items() if not vis
)
newly_visible = now_rooms_visible - then_rooms
newly_unpublished = now_rooms_not_visible & then_rooms
return newly_visible, newly_unpublished
return self.runInteraction(
"get_public_room_changes", get_public_room_changes_txn
)
class RoomStore(RoomWorkerStore, SearchStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def store_room(self, room_id, room_creator_user_id, is_public): def store_room(self, room_id, room_creator_user_id, is_public):
@ -227,16 +345,6 @@ class RoomStore(SQLBaseStore):
) )
self.hs.get_notifier().on_new_replication_data() self.hs.get_notifier().on_new_replication_data()
def get_public_room_ids(self):
return self._simple_select_onecol(
table="rooms",
keyvalues={
"is_public": True,
},
retcol="room_id",
desc="get_public_room_ids",
)
def get_room_count(self): def get_room_count(self):
"""Retrieve a list of all rooms """Retrieve a list of all rooms
""" """
@ -263,8 +371,8 @@ class RoomStore(SQLBaseStore):
}, },
) )
self._store_event_search_txn( self.store_event_search_txn(
txn, event, "content.topic", event.content["topic"] txn, event, "content.topic", event.content["topic"],
) )
def _store_room_name_txn(self, txn, event): def _store_room_name_txn(self, txn, event):
@ -279,14 +387,14 @@ class RoomStore(SQLBaseStore):
} }
) )
self._store_event_search_txn( self.store_event_search_txn(
txn, event, "content.name", event.content["name"] txn, event, "content.name", event.content["name"],
) )
def _store_room_message_txn(self, txn, event): def _store_room_message_txn(self, txn, event):
if hasattr(event, "content") and "body" in event.content: if hasattr(event, "content") and "body" in event.content:
self._store_event_search_txn( self.store_event_search_txn(
txn, event, "content.body", event.content["body"] txn, event, "content.body", event.content["body"],
) )
def _store_history_visibility_txn(self, txn, event): def _store_history_visibility_txn(self, txn, event):
@ -308,31 +416,6 @@ class RoomStore(SQLBaseStore):
event.content[key] event.content[key]
)) ))
def _store_event_search_txn(self, txn, event, key, value):
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search"
" (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
)
txn.execute(
sql,
(
event.event_id, event.room_id, key, value,
event.internal_metadata.stream_ordering,
event.origin_server_ts,
)
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
txn.execute(sql, (event.event_id, event.room_id, key, value,))
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
def add_event_report(self, room_id, event_id, user_id, reason, content, def add_event_report(self, room_id, event_id, user_id, reason, content,
received_ts): received_ts):
next_id = self._event_reports_id_gen.get_next() next_id = self._event_reports_id_gen.get_next()
@ -353,113 +436,6 @@ class RoomStore(SQLBaseStore):
def get_current_public_room_stream_id(self): def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token() return self._public_room_id_gen.get_current_token()
@cached(num_args=2, max_entries=100)
def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
"""Get pulbic rooms for a particular list, or across all lists.
Args:
stream_id (int)
network_tuple (ThirdPartyInstanceID): The list to use (None, None)
means the main list, None means all lsits.
"""
return self.runInteraction(
"get_public_room_ids_at_stream_id",
self.get_public_room_ids_at_stream_id_txn,
stream_id, network_tuple=network_tuple
)
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
network_tuple):
return {
rm
for rm, vis in self.get_published_at_stream_id_txn(
txn, stream_id, network_tuple=network_tuple
).items()
if vis
}
def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
if network_tuple:
# We want to get from a particular list. No aggregation required.
sql = ("""
SELECT room_id, visibility FROM public_room_list_stream
INNER JOIN (
SELECT room_id, max(stream_id) AS stream_id
FROM public_room_list_stream
WHERE stream_id <= ? %s
GROUP BY room_id
) grouped USING (room_id, stream_id)
""")
if network_tuple.appservice_id is not None:
txn.execute(
sql % ("AND appservice_id = ? AND network_id = ?",),
(stream_id, network_tuple.appservice_id, network_tuple.network_id,)
)
else:
txn.execute(
sql % ("AND appservice_id IS NULL",),
(stream_id,)
)
return dict(txn)
else:
# We want to get from all lists, so we need to aggregate the results
logger.info("Executing full list")
sql = ("""
SELECT room_id, visibility
FROM public_room_list_stream
INNER JOIN (
SELECT
room_id, max(stream_id) AS stream_id, appservice_id,
network_id
FROM public_room_list_stream
WHERE stream_id <= ?
GROUP BY room_id, appservice_id, network_id
) grouped USING (room_id, stream_id)
""")
txn.execute(
sql,
(stream_id,)
)
results = {}
# A room is visible if its visible on any list.
for room_id, visibility in txn:
results[room_id] = bool(visibility) or results.get(room_id, False)
return results
def get_public_room_changes(self, prev_stream_id, new_stream_id,
network_tuple):
def get_public_room_changes_txn(txn):
then_rooms = self.get_public_room_ids_at_stream_id_txn(
txn, prev_stream_id, network_tuple
)
now_rooms_dict = self.get_published_at_stream_id_txn(
txn, new_stream_id, network_tuple
)
now_rooms_visible = set(
rm for rm, vis in now_rooms_dict.items() if vis
)
now_rooms_not_visible = set(
rm for rm, vis in now_rooms_dict.items() if not vis
)
newly_visible = now_rooms_visible - then_rooms
newly_unpublished = now_rooms_not_visible & then_rooms
return newly_visible, newly_unpublished
return self.runInteraction(
"get_public_room_changes", get_public_room_changes_txn
)
def get_all_new_public_rooms(self, prev_id, current_id, limit): def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn): def get_all_new_public_rooms(txn):
sql = (""" sql = ("""
@ -533,16 +509,84 @@ class RoomStore(SQLBaseStore):
) )
self.is_room_blocked.invalidate((room_id,)) self.is_room_blocked.invalidate((room_id,))
def get_media_mxcs_in_room(self, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
room_id (str)
Returns:
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
"""
def _get_media_mxcs_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
local_media_mxcs = []
remote_media_mxcs = []
# Convert the IDs to MXC URIs
for media_id in local_mxcs:
local_media_mxcs.append("mxc://%s/%s" % (self.hostname, media_id))
for hostname, media_id in remote_mxcs:
remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
return local_media_mxcs, remote_media_mxcs
return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
def quarantine_media_ids_in_room(self, room_id, quarantined_by): def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines """For a room loops through all events with media and quarantines
the associated media the associated media
""" """
def _get_media_ids_in_room(txn): def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
total_media_quarantined = 0
# Now update all the tables to set the quarantined_by flag
txn.executemany("""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
""", ((quarantined_by, media_id) for media_id in local_mxcs))
txn.executemany(
"""
UPDATE remote_media_cache
SET quarantined_by = ?
WHERE media_origin = ? AND media_id = ?
""",
(
(quarantined_by, origin, media_id)
for origin, media_id in remote_mxcs
)
)
total_media_quarantined += len(local_mxcs)
total_media_quarantined += len(remote_mxcs)
return total_media_quarantined
return self.runInteraction(
"quarantine_media_in_room",
_quarantine_media_in_room_txn,
)
def _get_media_mxcs_in_room_txn(self, txn, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
txn (cursor)
room_id (str)
Returns:
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
"""
mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
next_token = self.get_current_events_token() + 1 next_token = self.get_current_events_token() + 1
local_media_mxcs = []
total_media_quarantined = 0 remote_media_mxcs = []
while next_token: while next_token:
sql = """ sql = """
@ -556,8 +600,6 @@ class RoomStore(SQLBaseStore):
txn.execute(sql, (room_id, next_token, True, False, 100)) txn.execute(sql, (room_id, next_token, True, False, 100))
next_token = None next_token = None
local_media_mxcs = []
remote_media_mxcs = []
for stream_ordering, content_json in txn: for stream_ordering, content_json in txn:
next_token = stream_ordering next_token = stream_ordering
content = json.loads(content_json) content = json.loads(content_json)
@ -577,29 +619,4 @@ class RoomStore(SQLBaseStore):
else: else:
remote_media_mxcs.append((hostname, media_id)) remote_media_mxcs.append((hostname, media_id))
# Now update all the tables to set the quarantined_by flag return local_media_mxcs, remote_media_mxcs
txn.executemany("""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
""", ((quarantined_by, media_id) for media_id in local_media_mxcs))
txn.executemany(
"""
UPDATE remote_media_cache
SET quarantined_by = ?
WHERE media_origin AND media_id = ?
""",
(
(quarantined_by, origin, media_id)
for origin, media_id in remote_media_mxcs
)
)
total_media_quarantined += len(local_media_mxcs)
total_media_quarantined += len(remote_media_mxcs)
return total_media_quarantined
return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,7 +18,7 @@ from twisted.internet import defer
from collections import namedtuple from collections import namedtuple
from ._base import SQLBaseStore from synapse.storage.events import EventsWorkerStore
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.util.caches import intern_string from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@ -37,6 +38,11 @@ RoomsForUser = namedtuple(
("room_id", "sender", "membership", "event_id", "stream_ordering") ("room_id", "sender", "membership", "event_id", "stream_ordering")
) )
GetRoomsForUserWithStreamOrdering = namedtuple(
"_GetRoomsForUserWithStreamOrdering",
("room_id", "stream_ordering",)
)
# We store this using a namedtuple so that we save about 3x space over using a # We store this using a namedtuple so that we save about 3x space over using a
# dict. # dict.
@ -48,97 +54,7 @@ ProfileInfo = namedtuple(
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
class RoomMemberStore(SQLBaseStore): class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(self, db_conn, hs):
super(RoomMemberStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
)
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
self._simple_insert_many_txn(
txn,
table="room_memberships",
values=[
{
"event_id": event.event_id,
"user_id": event.state_key,
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
"display_name": event.content.get("displayname", None),
"avatar_url": event.content.get("avatar_url", None),
}
for event in events
]
)
for event in events:
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key, event.internal_metadata.stream_ordering
)
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
)
# We update the local_invites table only if the event is "current",
# i.e., its something that has just happened.
# The only current event that can also be an outlier is if its an
# invite that has come in across federation.
is_new_state = not backfilled and (
not event.internal_metadata.is_outlier()
or event.internal_metadata.is_invite_from_remote()
)
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
self._simple_insert_txn(
txn,
table="local_invites",
values={
"event_id": event.event_id,
"invitee": event.state_key,
"inviter": event.sender,
"room_id": event.room_id,
"stream_id": event.internal_metadata.stream_ordering,
}
)
else:
sql = (
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
txn.execute(sql, (
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.state_key,
))
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (
stream_ordering,
True,
room_id,
user_id,
))
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
@cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True) @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
def get_hosts_in_room(self, room_id, cache_context): def get_hosts_in_room(self, room_id, cache_context):
"""Returns the set of all hosts currently in the room """Returns the set of all hosts currently in the room
@ -270,12 +186,32 @@ class RoomMemberStore(SQLBaseStore):
return results return results
@cachedInlineCallbacks(max_entries=500000, iterable=True) @cachedInlineCallbacks(max_entries=500000, iterable=True)
def get_rooms_for_user(self, user_id): def get_rooms_for_user_with_stream_ordering(self, user_id):
"""Returns a set of room_ids the user is currently joined to """Returns a set of room_ids the user is currently joined to
Args:
user_id (str)
Returns:
Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
the rooms the user is in currently, along with the stream ordering
of the most recent join for that user and room.
""" """
rooms = yield self.get_rooms_for_user_where_membership_is( rooms = yield self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN], user_id, membership_list=[Membership.JOIN],
) )
defer.returnValue(frozenset(
GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
for r in rooms
))
@defer.inlineCallbacks
def get_rooms_for_user(self, user_id, on_invalidate=None):
"""Returns a set of room_ids the user is currently joined to
"""
rooms = yield self.get_rooms_for_user_with_stream_ordering(
user_id, on_invalidate=on_invalidate,
)
defer.returnValue(frozenset(r.room_id for r in rooms)) defer.returnValue(frozenset(r.room_id for r in rooms))
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
@ -295,89 +231,6 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(user_who_share_room) defer.returnValue(user_who_share_room)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
sql = (
"UPDATE"
" room_memberships"
" SET"
" forgotten = 1"
" WHERE"
" user_id = ?"
" AND"
" room_id = ?"
)
txn.execute(sql, (user_id, room_id))
txn.call_after(self.was_forgotten_at.invalidate_all)
txn.call_after(self.did_forget.invalidate, (user_id, room_id))
self._invalidate_cache_and_stream(
txn, self.who_forgot_in_room, (room_id,)
)
return self.runInteraction("forget_membership", f)
@cachedInlineCallbacks(num_args=2)
def did_forget(self, user_id, room_id):
"""Returns whether user_id has elected to discard history for room_id.
Returns False if they have since re-joined."""
def f(txn):
sql = (
"SELECT"
" COUNT(*)"
" FROM"
" room_memberships"
" WHERE"
" user_id = ?"
" AND"
" room_id = ?"
" AND"
" forgotten = 0"
)
txn.execute(sql, (user_id, room_id))
rows = txn.fetchall()
return rows[0][0]
count = yield self.runInteraction("did_forget_membership", f)
defer.returnValue(count == 0)
@cachedInlineCallbacks(num_args=3)
def was_forgotten_at(self, user_id, room_id, event_id):
"""Returns whether user_id has elected to discard history for room_id at
event_id.
event_id must be a membership event."""
def f(txn):
sql = (
"SELECT"
" forgotten"
" FROM"
" room_memberships"
" WHERE"
" user_id = ?"
" AND"
" room_id = ?"
" AND"
" event_id = ?"
)
txn.execute(sql, (user_id, room_id, event_id))
rows = txn.fetchall()
return rows[0][0]
forgot = yield self.runInteraction("did_forget_membership_at", f)
defer.returnValue(forgot == 1)
@cached()
def who_forgot_in_room(self, room_id):
return self._simple_select_list(
table="room_memberships",
retcols=("user_id", "event_id"),
keyvalues={
"room_id": room_id,
"forgotten": 1,
},
desc="who_forgot"
)
def get_joined_users_from_context(self, event, context): def get_joined_users_from_context(self, event, context):
state_group = context.state_group state_group = context.state_group
if not state_group: if not state_group:
@ -600,6 +453,185 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(joined_hosts) defer.returnValue(joined_hosts)
@cached(max_entries=10000, iterable=True)
def _get_joined_hosts_cache(self, room_id):
return _JoinedHostsCache(self, room_id)
@cached()
def who_forgot_in_room(self, room_id):
return self._simple_select_list(
table="room_memberships",
retcols=("user_id", "event_id"),
keyvalues={
"room_id": room_id,
"forgotten": 1,
},
desc="who_forgot"
)
class RoomMemberStore(RoomMemberWorkerStore):
def __init__(self, db_conn, hs):
super(RoomMemberStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
)
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
self._simple_insert_many_txn(
txn,
table="room_memberships",
values=[
{
"event_id": event.event_id,
"user_id": event.state_key,
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
"display_name": event.content.get("displayname", None),
"avatar_url": event.content.get("avatar_url", None),
}
for event in events
]
)
for event in events:
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key, event.internal_metadata.stream_ordering
)
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
)
# We update the local_invites table only if the event is "current",
# i.e., its something that has just happened.
# The only current event that can also be an outlier is if its an
# invite that has come in across federation.
is_new_state = not backfilled and (
not event.internal_metadata.is_outlier()
or event.internal_metadata.is_invite_from_remote()
)
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
self._simple_insert_txn(
txn,
table="local_invites",
values={
"event_id": event.event_id,
"invitee": event.state_key,
"inviter": event.sender,
"room_id": event.room_id,
"stream_id": event.internal_metadata.stream_ordering,
}
)
else:
sql = (
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
txn.execute(sql, (
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.state_key,
))
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (
stream_ordering,
True,
room_id,
user_id,
))
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
sql = (
"UPDATE"
" room_memberships"
" SET"
" forgotten = 1"
" WHERE"
" user_id = ?"
" AND"
" room_id = ?"
)
txn.execute(sql, (user_id, room_id))
txn.call_after(self.was_forgotten_at.invalidate_all)
txn.call_after(self.did_forget.invalidate, (user_id, room_id))
self._invalidate_cache_and_stream(
txn, self.who_forgot_in_room, (room_id,)
)
return self.runInteraction("forget_membership", f)
@cachedInlineCallbacks(num_args=2)
def did_forget(self, user_id, room_id):
"""Returns whether user_id has elected to discard history for room_id.
Returns False if they have since re-joined."""
def f(txn):
sql = (
"SELECT"
" COUNT(*)"
" FROM"
" room_memberships"
" WHERE"
" user_id = ?"
" AND"
" room_id = ?"
" AND"
" forgotten = 0"
)
txn.execute(sql, (user_id, room_id))
rows = txn.fetchall()
return rows[0][0]
count = yield self.runInteraction("did_forget_membership", f)
defer.returnValue(count == 0)
@cachedInlineCallbacks(num_args=3)
def was_forgotten_at(self, user_id, room_id, event_id):
"""Returns whether user_id has elected to discard history for room_id at
event_id.
event_id must be a membership event."""
def f(txn):
sql = (
"SELECT"
" forgotten"
" FROM"
" room_memberships"
" WHERE"
" user_id = ?"
" AND"
" room_id = ?"
" AND"
" event_id = ?"
)
txn.execute(sql, (user_id, room_id, event_id))
rows = txn.fetchall()
return rows[0][0]
forgot = yield self.runInteraction("did_forget_membership_at", f)
defer.returnValue(forgot == 1)
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_add_membership_profile(self, progress, batch_size): def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get( target_min_stream_id = progress.get(
@ -675,10 +707,6 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(result) defer.returnValue(result)
@cached(max_entries=10000, iterable=True)
def _get_joined_hosts_cache(self, room_id):
return _JoinedHostsCache(self, room_id)
class _JoinedHostsCache(object): class _JoinedHostsCache(object):
"""Cache for joined hosts in a room that is optimised to handle updates """Cache for joined hosts in a room that is optimised to handle updates

View file

@ -13,5 +13,7 @@
* limitations under the License. * limitations under the License.
*/ */
INSERT into background_updates (update_name, progress_json) -- We no longer do this given we back it out again in schema 47
VALUES ('event_search_postgres_gist', '{}');
-- INSERT into background_updates (update_name, progress_json)
-- VALUES ('event_search_postgres_gist', '{}');

View file

@ -0,0 +1,17 @@
/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT into background_updates (update_name, progress_json)
VALUES ('event_search_postgres_gin', '{}');

View file

@ -0,0 +1,28 @@
/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Temporary staging area for push actions that have been calculated for an
-- event, but the event hasn't yet been persisted.
-- When the event is persisted the rows are moved over to the
-- event_push_actions table.
CREATE TABLE event_push_actions_staging (
event_id TEXT NOT NULL,
user_id TEXT NOT NULL,
actions TEXT NOT NULL,
notif SMALLINT NOT NULL,
highlight SMALLINT NOT NULL
);
CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id);

View file

@ -0,0 +1,37 @@
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.engines import PostgresEngine
def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
# if we already have some state groups, we want to start making new
# ones with a higher id.
cur.execute("SELECT max(id) FROM state_groups")
row = cur.fetchone()
if row[0] is None:
start_val = 1
else:
start_val = row[0] + 1
cur.execute(
"CREATE SEQUENCE state_group_id_seq START WITH %s",
(start_val, ),
)
def run_upgrade(*args, **kwargs):
pass

View file

@ -13,25 +13,32 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import namedtuple
import logging
import re
import ujson as json
from twisted.internet import defer from twisted.internet import defer
from .background_updates import BackgroundUpdateStore from .background_updates import BackgroundUpdateStore
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging
import re
import ujson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SearchEntry = namedtuple('SearchEntry', [
'key', 'value', 'event_id', 'room_id', 'stream_ordering',
'origin_server_ts',
])
class SearchStore(BackgroundUpdateStore): class SearchStore(BackgroundUpdateStore):
EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SearchStore, self).__init__(db_conn, hs) super(SearchStore, self).__init__(db_conn, hs)
@ -42,23 +49,34 @@ class SearchStore(BackgroundUpdateStore):
self.EVENT_SEARCH_ORDER_UPDATE_NAME, self.EVENT_SEARCH_ORDER_UPDATE_NAME,
self._background_reindex_search_order self._background_reindex_search_order
) )
self.register_background_update_handler(
# we used to have a background update to turn the GIN index into a
# GIST one; we no longer do that (obviously) because we actually want
# a GIN index. However, it's possible that some people might still have
# the background update queued, so we register a handler to clear the
# background update.
self.register_noop_background_update(
self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME, self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME,
self._background_reindex_gist_search )
self.register_background_update_handler(
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME,
self._background_reindex_gin_search
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_reindex_search(self, progress, batch_size): def _background_reindex_search(self, progress, batch_size):
# we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"] target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"] max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0) rows_inserted = progress.get("rows_inserted", 0)
INSERT_CLUMP_SIZE = 1000
TYPES = ["m.room.name", "m.room.message", "m.room.topic"] TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
def reindex_search_txn(txn): def reindex_search_txn(txn):
sql = ( sql = (
"SELECT stream_ordering, event_id, room_id, type, content FROM events" "SELECT stream_ordering, event_id, room_id, type, content, "
" origin_server_ts FROM events"
" WHERE ? <= stream_ordering AND stream_ordering < ?" " WHERE ? <= stream_ordering AND stream_ordering < ?"
" AND (%s)" " AND (%s)"
" ORDER BY stream_ordering DESC" " ORDER BY stream_ordering DESC"
@ -67,6 +85,10 @@ class SearchStore(BackgroundUpdateStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
# we could stream straight from the results into
# store_search_entries_txn with a generator function, but that
# would mean having two cursors open on the database at once.
# Instead we just build a list of results.
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
if not rows: if not rows:
return 0 return 0
@ -79,6 +101,8 @@ class SearchStore(BackgroundUpdateStore):
event_id = row["event_id"] event_id = row["event_id"]
room_id = row["room_id"] room_id = row["room_id"]
etype = row["type"] etype = row["type"]
stream_ordering = row["stream_ordering"]
origin_server_ts = row["origin_server_ts"]
try: try:
content = json.loads(row["content"]) content = json.loads(row["content"])
except Exception: except Exception:
@ -93,6 +117,8 @@ class SearchStore(BackgroundUpdateStore):
elif etype == "m.room.name": elif etype == "m.room.name":
key = "content.name" key = "content.name"
value = content["name"] value = content["name"]
else:
raise Exception("unexpected event type %s" % etype)
except (KeyError, AttributeError): except (KeyError, AttributeError):
# If the event is missing a necessary field then # If the event is missing a necessary field then
# skip over it. # skip over it.
@ -103,25 +129,16 @@ class SearchStore(BackgroundUpdateStore):
# then skip over it # then skip over it
continue continue
event_search_rows.append((event_id, room_id, key, value)) event_search_rows.append(SearchEntry(
key=key,
value=value,
event_id=event_id,
room_id=room_id,
stream_ordering=stream_ordering,
origin_server_ts=origin_server_ts,
))
if isinstance(self.database_engine, PostgresEngine): self.store_search_entries_txn(txn, event_search_rows)
sql = (
"INSERT INTO event_search (event_id, room_id, key, vector)"
" VALUES (?,?,?,to_tsvector('english', ?))"
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
progress = { progress = {
"target_min_stream_id_inclusive": target_min_stream_id, "target_min_stream_id_inclusive": target_min_stream_id,
@ -145,25 +162,48 @@ class SearchStore(BackgroundUpdateStore):
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_reindex_gist_search(self, progress, batch_size): def _background_reindex_gin_search(self, progress, batch_size):
"""This handles old synapses which used GIST indexes, if any;
converting them back to be GIN as per the actual schema.
"""
def create_index(conn): def create_index(conn):
conn.rollback() conn.rollback()
# we have to set autocommit, because postgres refuses to
# CREATE INDEX CONCURRENTLY without it.
conn.set_session(autocommit=True) conn.set_session(autocommit=True)
try:
c = conn.cursor() c = conn.cursor()
# if we skipped the conversion to GIST, we may already/still
# have an event_search_fts_idx; unfortunately postgres 9.4
# doesn't support CREATE INDEX IF EXISTS so we just catch the
# exception and ignore it.
import psycopg2
try:
c.execute( c.execute(
"CREATE INDEX CONCURRENTLY event_search_fts_idx_gist" "CREATE INDEX CONCURRENTLY event_search_fts_idx"
" ON event_search USING GIST (vector)" " ON event_search USING GIN (vector)"
)
except psycopg2.ProgrammingError as e:
logger.warn(
"Ignoring error %r when trying to switch from GIST to GIN",
e
) )
c.execute("DROP INDEX event_search_fts_idx") # we should now be able to delete the GIST index.
c.execute(
"DROP INDEX IF EXISTS event_search_fts_idx_gist"
)
finally:
conn.set_session(autocommit=False) conn.set_session(autocommit=False)
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
yield self.runWithConnection(create_index) yield self.runWithConnection(create_index)
yield self._end_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME) yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
defer.returnValue(1) defer.returnValue(1)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -242,6 +282,85 @@ class SearchStore(BackgroundUpdateStore):
defer.returnValue(num_rows) defer.returnValue(num_rows)
def store_event_search_txn(self, txn, event, key, value):
"""Add event to the search table
Args:
txn (cursor):
event (EventBase):
key (str):
value (str):
"""
self.store_search_entries_txn(
txn,
(SearchEntry(
key=key,
value=value,
event_id=event.event_id,
room_id=event.room_id,
stream_ordering=event.internal_metadata.stream_ordering,
origin_server_ts=event.origin_server_ts,
),),
)
def store_search_entries_txn(self, txn, entries):
"""Add entries to the search table
Args:
txn (cursor):
entries (iterable[SearchEntry]):
entries to be added to the table
"""
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search"
" (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
)
args = ((
entry.event_id, entry.room_id, entry.key, entry.value,
entry.stream_ordering, entry.origin_server_ts,
) for entry in entries)
# inserts to a GIN index are normally batched up into a pending
# list, and then all committed together once the list gets to a
# certain size. The trouble with that is that postgres (pre-9.5)
# uses work_mem to determine the length of the list, and work_mem
# is typically very large.
#
# We therefore reduce work_mem while we do the insert.
#
# (postgres 9.5 uses the separate gin_pending_list_limit setting,
# so doesn't suffer the same problem, but changing work_mem will
# be harmless)
#
# Note that we don't need to worry about restoring it on
# exception, because exceptions will cause the transaction to be
# rolled back, including the effects of the SET command.
#
# Also: we use SET rather than SET LOCAL because there's lots of
# other stuff going on in this transaction, which want to have the
# normal work_mem setting.
txn.execute("SET work_mem='256kB'")
txn.executemany(sql, args)
txn.execute("RESET work_mem")
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
args = ((
entry.event_id, entry.room_id, entry.key, entry.value,
) for entry in entries)
txn.executemany(sql, args)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@defer.inlineCallbacks @defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys): def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys. """Performs a full text search over events with given keys.

View file

@ -22,12 +22,12 @@ from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
class SignatureStore(SQLBaseStore): class SignatureWorkerStore(SQLBaseStore):
"""Persistence for event signatures and hashes"""
@cached() @cached()
def get_event_reference_hash(self, event_id): def get_event_reference_hash(self, event_id):
return self._get_event_reference_hashes_txn(event_id) # This is a dummy function to allow get_event_reference_hashes
# to use its cache
raise NotImplementedError()
@cachedList(cached_method_name="get_event_reference_hash", @cachedList(cached_method_name="get_event_reference_hash",
list_name="event_ids", num_args=1) list_name="event_ids", num_args=1)
@ -74,6 +74,10 @@ class SignatureStore(SQLBaseStore):
txn.execute(query, (event_id, )) txn.execute(query, (event_id, ))
return {k: v for k, v in txn} return {k: v for k, v in txn}
class SignatureStore(SignatureWorkerStore):
"""Persistence for event signatures and hashes"""
def _store_event_reference_hashes_txn(self, txn, events): def _store_event_reference_hashes_txn(self, txn, events):
"""Store a hash for a PDU """Store a hash for a PDU
Args: Args:

View file

@ -42,11 +42,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
return len(self.delta_ids) if self.delta_ids else 0 return len(self.delta_ids) if self.delta_ids else 0
class StateGroupReadStore(SQLBaseStore): class StateGroupWorkerStore(SQLBaseStore):
"""The read-only parts of StateGroupStore """The parts of StateGroupStore that can be called from workers.
None of these functions write to the state tables, so are suitable for
including in the SlavedStores.
""" """
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
@ -54,7 +51,7 @@ class StateGroupReadStore(SQLBaseStore):
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(StateGroupReadStore, self).__init__(db_conn, hs) super(StateGroupWorkerStore, self).__init__(db_conn, hs)
self._state_group_cache = DictionaryCache( self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
@ -142,6 +139,20 @@ class StateGroupReadStore(SQLBaseStore):
defer.returnValue(group_to_state) defer.returnValue(group_to_state)
@defer.inlineCallbacks
def get_state_ids_for_group(self, state_group):
"""Get the state IDs for the given state group
Args:
state_group (int)
Returns:
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
"""
group_to_state = yield self._get_state_for_groups((state_group,))
defer.returnValue(group_to_state[state_group])
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids): def get_state_groups(self, room_id, event_ids):
""" Get the state groups for the given list of event_ids """ Get the state groups for the given list of event_ids
@ -549,116 +560,66 @@ class StateGroupReadStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
def store_state_group(self, event_id, room_id, prev_group, delta_ids,
current_state_ids):
"""Store a new set of state, returning a newly assigned state group.
class StateStore(StateGroupReadStore, BackgroundUpdateStore): Args:
""" Keeps track of the state at a given event. event_id (str): The event ID for which the state was calculated
room_id (str)
prev_group (int|None): A previous state group for the room, optional.
delta_ids (dict|None): The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
current_state_ids (dict): The state to store. Map of (type, state_key)
to event_id.
This is done by the concept of `state groups`. Every event is a assigned Returns:
a state group (identified by an arbitrary string), which references a Deferred[int]: The state group ID
collection of state events. The current state of an event is then the
collection of state events referenced by the event's state group.
Hence, every change in the current state causes a new state group to be
generated. However, if no change happens (e.g., if we get a message event
with only one parent it inherits the state group from its parent.)
There are three tables:
* `state_groups`: Stores group name, first event with in the group and
room id.
* `event_to_state_groups`: Maps events to state groups.
* `state_groups_state`: Maps state group to state events.
""" """
def _store_state_group_txn(txn):
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" if current_state_ids is None:
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
)
self.register_background_update_handler(
self.STATE_GROUP_INDEX_UPDATE_NAME,
self._background_index_state,
)
self.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
index_name="current_state_events_member_index",
table="current_state_events",
columns=["state_key"],
where_clause="type='m.room.member'",
)
def _have_persisted_state_group_txn(self, txn, state_group):
txn.execute(
"SELECT count(*) FROM state_groups WHERE id = ?",
(state_group,)
)
row = txn.fetchone()
return row and row[0]
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
if context.current_state_ids is None:
# AFAIK, this can never happen # AFAIK, this can never happen
logger.error( raise Exception("current_state_ids cannot be None")
"Non-outlier event %s had current_state_ids==None",
event.event_id)
continue
# if the event was rejected, just give it the same state as its state_group = self.database_engine.get_next_state_group_id(txn)
# predecessor.
if context.rejected:
state_groups[event.event_id] = context.prev_group
continue
state_groups[event.event_id] = context.state_group
if self._have_persisted_state_group_txn(txn, context.state_group):
continue
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="state_groups", table="state_groups",
values={ values={
"id": context.state_group, "id": state_group,
"room_id": event.room_id, "room_id": room_id,
"event_id": event.event_id, "event_id": event_id,
}, },
) )
# We persist as a delta if we can, while also ensuring the chain # We persist as a delta if we can, while also ensuring the chain
# of deltas isn't tooo long, as otherwise read performance degrades. # of deltas isn't tooo long, as otherwise read performance degrades.
if context.prev_group: if prev_group:
is_in_db = self._simple_select_one_onecol_txn( is_in_db = self._simple_select_one_onecol_txn(
txn, txn,
table="state_groups", table="state_groups",
keyvalues={"id": context.prev_group}, keyvalues={"id": prev_group},
retcol="id", retcol="id",
allow_none=True, allow_none=True,
) )
if not is_in_db: if not is_in_db:
raise Exception( raise Exception(
"Trying to persist state with unpersisted prev_group: %r" "Trying to persist state with unpersisted prev_group: %r"
% (context.prev_group,) % (prev_group,)
) )
potential_hops = self._count_state_group_hops_txn( potential_hops = self._count_state_group_hops_txn(
txn, context.prev_group txn, prev_group
) )
if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS: if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="state_group_edges", table="state_group_edges",
values={ values={
"state_group": context.state_group, "state_group": state_group,
"prev_state_group": context.prev_group, "prev_state_group": prev_group,
}, },
) )
@ -667,13 +628,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
table="state_groups_state", table="state_groups_state",
values=[ values=[
{ {
"state_group": context.state_group, "state_group": state_group,
"room_id": event.room_id, "room_id": room_id,
"type": key[0], "type": key[0],
"state_key": key[1], "state_key": key[1],
"event_id": state_id, "event_id": state_id,
} }
for key, state_id in context.delta_ids.iteritems() for key, state_id in delta_ids.iteritems()
], ],
) )
else: else:
@ -682,13 +643,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
table="state_groups_state", table="state_groups_state",
values=[ values=[
{ {
"state_group": context.state_group, "state_group": state_group,
"room_id": event.room_id, "room_id": room_id,
"type": key[0], "type": key[0],
"state_key": key[1], "state_key": key[1],
"event_id": state_id, "event_id": state_id,
} }
for key, state_id in context.current_state_ids.iteritems() for key, state_id in current_state_ids.iteritems()
], ],
) )
@ -699,28 +660,14 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
txn.call_after( txn.call_after(
self._state_group_cache.update, self._state_group_cache.update,
self._state_group_cache.sequence, self._state_group_cache.sequence,
key=context.state_group, key=state_group,
value=dict(context.current_state_ids), value=dict(current_state_ids),
full=True, full=True,
) )
self._simple_insert_many_txn( return state_group
txn,
table="event_to_state_groups",
values=[
{
"state_group": state_group_id,
"event_id": event_id,
}
for event_id, state_group_id in state_groups.iteritems()
],
)
for event_id, state_group_id in state_groups.iteritems(): return self.runInteraction("store_state_group", _store_state_group_txn)
txn.call_after(
self._get_state_group_for_event.prefill,
(event_id,), state_group_id
)
def _count_state_group_hops_txn(self, txn, state_group): def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree. """Given a state group, count how many hops there are in the tree.
@ -763,8 +710,79 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
return count return count
def get_next_state_group(self):
return self._state_groups_id_gen.get_next() class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
""" Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is a assigned
a state group (identified by an arbitrary string), which references a
collection of state events. The current state of an event is then the
collection of state events referenced by the event's state group.
Hence, every change in the current state causes a new state group to be
generated. However, if no change happens (e.g., if we get a message event
with only one parent it inherits the state group from its parent.)
There are three tables:
* `state_groups`: Stores group name, first event with in the group and
room id.
* `event_to_state_groups`: Maps events to state groups.
* `state_groups_state`: Maps state group to state events.
"""
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
)
self.register_background_update_handler(
self.STATE_GROUP_INDEX_UPDATE_NAME,
self._background_index_state,
)
self.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
index_name="current_state_events_member_index",
table="current_state_events",
columns=["state_key"],
where_clause="type='m.room.member'",
)
def _store_event_state_mappings_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
# if the event was rejected, just give it the same state as its
# predecessor.
if context.rejected:
state_groups[event.event_id] = context.prev_group
continue
state_groups[event.event_id] = context.state_group
self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
{
"state_group": state_group_id,
"event_id": event_id,
}
for event_id, state_group_id in state_groups.iteritems()
],
)
for event_id, state_group_id in state_groups.iteritems():
txn.call_after(
self._get_state_group_for_event.prefill,
(event_id,), state_group_id
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_deduplicate_state(self, progress, batch_size): def _background_deduplicate_state(self, progress, batch_size):

Some files were not shown because too many files have changed in this diff Show more