Merge remote-tracking branch 'origin/develop' into 3218-official-prom

This commit is contained in:
Amber Brown 2018-05-28 18:57:23 +10:00
commit 754826a830
47 changed files with 686 additions and 197 deletions

View file

@ -1,3 +1,59 @@
Changes in synapse v0.30.0 (2018-05-24)
==========================================
'Server Notices' are a new feature introduced in Synapse 0.30. They provide a
channel whereby server administrators can send messages to users on the server.
They are used as part of communication of the server policies (see ``docs/consent_tracking.md``),
however the intention is that they may also find a use for features such
as "Message of the day".
This feature is specific to Synapse, but uses standard Matrix communication mechanisms,
so should work with any Matrix client. For more details see ``docs/server_notices.md``
Further Server Notices/Consent Tracking Support:
* Allow overriding the server_notices user's avatar (PR #3273)
* Use the localpart in the consent uri (PR #3272)
* Support for putting %(consent_uri)s in messages (PR #3271)
* Block attempts to send server notices to remote users (PR #3270)
* Docs on consent bits (PR #3268)
Changes in synapse v0.30.0-rc1 (2018-05-23)
==========================================
Server Notices/Consent Tracking Support:
* ConsentResource to gather policy consent from users (PR #3213)
* Move RoomCreationHandler out of synapse.handlers.Handlers (PR #3225)
* Infrastructure for a server notices room (PR #3232)
* Send users a server notice about consent (PR #3236)
* Reject attempts to send event before privacy consent is given (PR #3257)
* Add a 'has_consented' template var to consent forms (PR #3262)
* Fix dependency on jinja2 (PR #3263)
Features:
* Cohort analytics (PR #3163, #3241, #3251)
* Add lxml to docker image for web previews (PR #3239) Thanks to @ptman!
* Add in flight request metrics (PR #3252)
Changes:
* Remove unused `update_external_syncs` (PR #3233)
* Use stream rather depth ordering for push actions (PR #3212)
* Make purge_history operate on tokens (PR #3221)
* Don't support limitless pagination (PR #3265)
Bug Fixes:
* Fix logcontext resource usage tracking (PR #3258)
* Fix error in handling receipts (PR #3235)
* Stop the transaction cache caching failures (PR #3255)
Changes in synapse v0.29.1 (2018-05-17) Changes in synapse v0.29.1 (2018-05-17)
========================================== ==========================================
Changes: Changes:

160
docs/consent_tracking.md Normal file
View file

@ -0,0 +1,160 @@
Support in Synapse for tracking agreement to server terms and conditions
========================================================================
Synapse 0.30 introduces support for tracking whether users have agreed to the
terms and conditions set by the administrator of a server - and blocking access
to the server until they have.
There are several parts to this functionality; each requires some specific
configuration in `homeserver.yaml` to be enabled.
Note that various parts of the configuation and this document refer to the
"privacy policy": agreement with a privacy policy is one particular use of this
feature, but of course adminstrators can specify other terms and conditions
unrelated to "privacy" per se.
Collecting policy agreement from a user
---------------------------------------
Synapse can be configured to serve the user a simple policy form with an
"accept" button. Clicking "Accept" records the user's acceptance in the
database and shows a success page.
To enable this, first create templates for the policy and success pages.
These should be stored on the local filesystem.
These templates use the [Jinja2](http://jinja.pocoo.org) templating language,
and [docs/privacy_policy_templates](privacy_policy_templates) gives
examples of the sort of thing that can be done.
Note that the templates must be stored under a name giving the language of the
template - currently this must always be `en` (for "English");
internationalisation support is intended for the future.
The template for the policy itself should be versioned and named according to
the version: for example `1.0.html`. The version of the policy which the user
has agreed to is stored in the database.
Once the templates are in place, make the following changes to `homeserver.yaml`:
1. Add a `user_consent` section, which should look like:
```yaml
user_consent:
template_dir: privacy_policy_templates
version: 1.0
```
`template_dir` points to the directory containing the policy
templates. `version` defines the version of the policy which will be served
to the user. In the example above, Synapse will serve
`privacy_policy_templates/en/1.0.html`.
2. Add a `form_secret` setting at the top level:
```yaml
form_secret: "<unique secret>"
```
This should be set to an arbitrary secret string (try `pwgen -y 30` to
generate suitable secrets).
More on what this is used for below.
3. Add `consent` wherever the `client` resource is currently enabled in the
`listeners` configuration. For example:
```yaml
listeners:
- port: 8008
resources:
- names:
- client
- consent
```
Finally, ensure that `jinja2` is installed. If you are using a virtualenv, this
should be a matter of `pip install Jinja2`. On debian, try `apt-get install
python-jinja2`.
Once this is complete, and the server has been restarted, try visiting
`https://<server>/_matrix/consent`. If correctly configured, this should give
an error "Missing string query parameter 'u'". It is now possible to manually
construct URIs where users can give their consent.
### Constructing the consent URI
It may be useful to manually construct the "consent URI" for a given user - for
instance, in order to send them an email asking them to consent. To do this,
take the base `https://<server>/_matrix/consent` URL and add the following
query parameters:
* `u`: the user id of the user. This can either be a full MXID
(`@user:server.com`) or just the localpart (`user`).
* `h`: hex-encoded HMAC-SHA256 of `u` using the `form_secret` as a key. It is
possible to calculate this on the commandline with something like:
```bash
echo -n '<user>' | openssl sha256 -hmac '<form_secret>'
```
This should result in a URI which looks something like:
`https://<server>/_matrix/consent?u=<user>&h=68a152465a4d...`.
Sending users a server notice asking them to agree to the policy
----------------------------------------------------------------
It is possible to configure Synapse to send a [server
notice](server_notices.md) to anybody who has not yet agreed to the current
version of the policy. To do so:
* ensure that the consent resource is configured, as in the previous section
* ensure that server notices are configured, as in [server_notices.md](server_notices.md).
* Add `server_notice_content` under `user_consent` in `homeserver.yaml`. For
example:
```yaml
user_consent:
server_notice_content:
msgtype: m.text
body: >-
Please give your consent to the privacy policy at %(consent_uri)s.
```
Synapse automatically replaces the placeholder `%(consent_uri)s` with the
consent uri for that user.
* ensure that `public_baseurl` is set in `homeserver.yaml`, and gives the base
URI that clients use to connect to the server. (It is used to construct
`consent_uri` in the server notice.)
Blocking users from using the server until they agree to the policy
-------------------------------------------------------------------
Synapse can be configured to block any attempts to join rooms or send messages
until the user has given their agreement to the policy. (Joining the server
notices room is exempted from this).
To enable this, add `block_events_error` under `user_consent`. For example:
```yaml
user_consent:
block_events_error: >-
You can't send any messages until you consent to the privacy policy at
%(consent_uri)s.
```
Synapse automatically replaces the placeholder `%(consent_uri)s` with the
consent uri for that user.
ensure that `public_baseurl` is set in `homeserver.yaml`, and gives the base
URI that clients use to connect to the server. (It is used to construct
`consent_uri` in the error.)

43
docs/manhole.md Normal file
View file

@ -0,0 +1,43 @@
Using the synapse manhole
=========================
The "manhole" allows server administrators to access a Python shell on a running
Synapse installation. This is a very powerful mechanism for administration and
debugging.
To enable it, first uncomment the `manhole` listener configuration in
`homeserver.yaml`:
```yaml
listeners:
- port: 9000
bind_addresses: ['::1', '127.0.0.1']
type: manhole
```
(`bind_addresses` in the above is important: it ensures that access to the
manhole is only possible for local users).
Note that this will give administrative access to synapse to **all users** with
shell access to the server. It should therefore **not** be enabled in
environments where untrusted users have shell access.
Then restart synapse, and point an ssh client at port 9000 on localhost, using
the username `matrix`:
```bash
ssh -p9000 matrix@localhost
```
The password is `rabbithole`.
This gives a Python REPL in which `hs` gives access to the
`synapse.server.HomeServer` object - which in turn gives access to many other
parts of the process.
As a simple example, retrieving an event from the database:
```
>>> hs.get_datastore().get_event('$1416420717069yeQaw:matrix.org')
<Deferred at 0x7ff253fc6998 current result: <FrozenEvent event_id='$1416420717069yeQaw:matrix.org', type='m.room.create', state_key=''>>
```

View file

@ -1,23 +0,0 @@
If enabling the 'consent' resource in synapse, you will need some templates
for the HTML to be served to the user. This directory contains very simple
examples of the sort of thing that can be done.
You'll need to add this sort of thing to your homeserver.yaml:
```
form_secret: <unique but arbitrary secret>
user_consent:
template_dir: docs/privacy_policy_templates
version: 1.0
```
You should then be able to enable the `consent` resource under a `listener`
entry. For example:
```
listeners:
- port: 8008
resources:
- names: [client, consent]
```

74
docs/server_notices.md Normal file
View file

@ -0,0 +1,74 @@
Server Notices
==============
'Server Notices' are a new feature introduced in Synapse 0.30. They provide a
channel whereby server administrators can send messages to users on the server.
They are used as part of communication of the server polices(see
[consent_tracking.md](consent_tracking.md)), however the intention is that
they may also find a use for features such as "Message of the day".
This is a feature specific to Synapse, but it uses standard Matrix
communication mechanisms, so should work with any Matrix client.
User experience
---------------
When the user is first sent a server notice, they will get an invitation to a
room (typically called 'Server Notices', though this is configurable in
`homeserver.yaml`). They will be **unable to reject** this invitation -
attempts to do so will receive an error.
Once they accept the invitation, they will see the notice message in the room
history; it will appear to have come from the 'server notices user' (see
below).
The user is prevented from sending any messages in this room by the power
levels.
Having joined the room, the user can leave the room if they want. Subsequent
server notices will then cause a new room to be created.
Synapse configuration
---------------------
Server notices come from a specific user id on the server. Server
administrators are free to choose the user id - something like `server` is
suggested, meaning the notices will come from
`@server:<your_server_name>`. Once the Server Notices user is configured, that
user id becomes a special, privileged user, so administrators should ensure
that **it is not already allocated**.
In order to support server notices, it is necessary to add some configuration
to the `homeserver.yaml` file. In particular, you should add a `server_notices`
section, which should look like this:
```yaml
server_notices:
system_mxid_localpart: server
system_mxid_display_name: "Server Notices"
system_mxid_avatar_url: "mxc://server.com/oumMVlgDnLYFaPVkExemNVVZ"
room_name: "Server Notices"
```
The only compulsory setting is `system_mxid_localpart`, which defines the user
id of the Server Notices user, as above. `room_name` defines the name of the
room which will be created.
`system_mxid_display_name` and `system_mxid_avatar_url` can be used to set the
displayname and avatar of the Server Notices user.
Sending notices
---------------
As of the current version of synapse, there is no convenient interface for
sending notices (other than the automated ones sent as part of consent
tracking).
In the meantime, it is possible to test this feature using the manhole. Having
gone into the manhole as described in [manhole.md](manhole.md), a notice can be
sent with something like:
```
>>> hs.get_server_notices_manager().send_notice('@user:server.com', {'msgtype':'m.text', 'body':'foo'})
```

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.29.1" __version__ = "0.30.0"

View file

@ -53,6 +53,7 @@ class Codes(object):
INVALID_USERNAME = "M_INVALID_USERNAME" INVALID_USERNAME = "M_INVALID_USERNAME"
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN" CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
class CodeMessageException(RuntimeError): class CodeMessageException(RuntimeError):

View file

@ -32,7 +32,8 @@ DEFAULT_CONFIG = """\
# #
# 'server_notice_content', if enabled, will send a user a "Server Notice" # 'server_notice_content', if enabled, will send a user a "Server Notice"
# asking them to consent to the privacy policy. The 'server_notices' section # asking them to consent to the privacy policy. The 'server_notices' section
# must also be configured for this to work. # must also be configured for this to work. Notices will *not* be sent to
# guest users unless 'send_server_notice_to_guests' is set to true.
# #
# 'block_events_error', if set, will block any attempts to send events # 'block_events_error', if set, will block any attempts to send events
# until the user consents to the privacy policy. The value of the setting is # until the user consents to the privacy policy. The value of the setting is
@ -43,10 +44,14 @@ DEFAULT_CONFIG = """\
# version: 1.0 # version: 1.0
# server_notice_content: # server_notice_content:
# msgtype: m.text # msgtype: m.text
# body: | # body: >-
# Pls do consent kthx # To continue using this homeserver you must review and agree to the
# block_events_error: | # terms and conditions at %(consent_uri)s
# You can't send any messages until you consent to the privacy policy. # send_server_notice_to_guests: True
# block_events_error: >-
# To continue using this homeserver you must review and agree to the
# terms and conditions at %(consent_uri)s
#
""" """
@ -57,6 +62,7 @@ class ConsentConfig(Config):
self.user_consent_version = None self.user_consent_version = None
self.user_consent_template_dir = None self.user_consent_template_dir = None
self.user_consent_server_notice_content = None self.user_consent_server_notice_content = None
self.user_consent_server_notice_to_guests = False
self.block_events_without_consent_error = None self.block_events_without_consent_error = None
def read_config(self, config): def read_config(self, config):
@ -71,6 +77,9 @@ class ConsentConfig(Config):
self.block_events_without_consent_error = consent_config.get( self.block_events_without_consent_error = consent_config.get(
"block_events_error", "block_events_error",
) )
self.user_consent_server_notice_to_guests = bool(consent_config.get(
"send_server_notice_to_guests", False,
))
def default_config(self, **kwargs): def default_config(self, **kwargs):
return DEFAULT_CONFIG return DEFAULT_CONFIG

View file

@ -26,12 +26,13 @@ DEFAULT_CONFIG = """\
# setting, which defines the id of the user which will be used to send the # setting, which defines the id of the user which will be used to send the
# notices. # notices.
# #
# It's also possible to override the room name, or the display name of the # It's also possible to override the room name, the display name of the
# "notices" user. # "notices" user, and the avatar for the user.
# #
# server_notices: # server_notices:
# system_mxid_localpart: notices # system_mxid_localpart: notices
# system_mxid_display_name: "Server Notices" # system_mxid_display_name: "Server Notices"
# system_mxid_avatar_url: "mxc://server.com/oumMVlgDnLYFaPVkExemNVVZ"
# room_name: "Server Notices" # room_name: "Server Notices"
""" """
@ -48,6 +49,10 @@ class ServerNoticesConfig(Config):
The display name to use for the server notices user. The display name to use for the server notices user.
None if server notices are not enabled. None if server notices are not enabled.
server_notices_mxid_avatar_url (str|None):
The display name to use for the server notices user.
None if server notices are not enabled.
server_notices_room_name (str|None): server_notices_room_name (str|None):
The name to use for the server notices room. The name to use for the server notices room.
None if server notices are not enabled. None if server notices are not enabled.
@ -56,6 +61,7 @@ class ServerNoticesConfig(Config):
super(ServerNoticesConfig, self).__init__() super(ServerNoticesConfig, self).__init__()
self.server_notices_mxid = None self.server_notices_mxid = None
self.server_notices_mxid_display_name = None self.server_notices_mxid_display_name = None
self.server_notices_mxid_avatar_url = None
self.server_notices_room_name = None self.server_notices_room_name = None
def read_config(self, config): def read_config(self, config):
@ -68,7 +74,10 @@ class ServerNoticesConfig(Config):
mxid_localpart, self.server_name, mxid_localpart, self.server_name,
).to_string() ).to_string()
self.server_notices_mxid_display_name = c.get( self.server_notices_mxid_display_name = c.get(
'system_mxid_display_name', 'Server Notices', 'system_mxid_display_name', None,
)
self.server_notices_mxid_avatar_url = c.get(
'system_mxid_avatar_url', None,
) )
# todo: i18n # todo: i18n
self.server_notices_room_name = c.get('room_name', "Server Notices") self.server_notices_room_name = c.get('room_name', "Server Notices")

View file

@ -20,6 +20,8 @@ from frozendict import frozendict
import re import re
from six import string_types
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded # (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'. # by a match for 'stuff'.
@ -277,7 +279,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
if only_event_fields: if only_event_fields:
if (not isinstance(only_event_fields, list) or if (not isinstance(only_event_fields, list) or
not all(isinstance(f, basestring) for f in only_event_fields)): not all(isinstance(f, string_types) for f in only_event_fields)):
raise TypeError("only_event_fields must be a list of strings") raise TypeError("only_event_fields must be a list of strings")
d = only_fields(d, only_event_fields) d = only_fields(d, only_event_fields)

View file

@ -17,6 +17,8 @@ from synapse.types import EventID, RoomID, UserID
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from six import string_types
class EventValidator(object): class EventValidator(object):
@ -49,7 +51,7 @@ class EventValidator(object):
strings.append("state_key") strings.append("state_key")
for s in strings: for s in strings:
if not isinstance(getattr(event, s), basestring): if not isinstance(getattr(event, s), string_types):
raise SynapseError(400, "Not '%s' a string type" % (s,)) raise SynapseError(400, "Not '%s' a string type" % (s,))
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
@ -88,5 +90,5 @@ class EventValidator(object):
for s in keys: for s in keys:
if s not in d: if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,)) raise SynapseError(400, "'%s' not in content" % (s,))
if not isinstance(d[s], basestring): if not isinstance(d[s], string_types):
raise SynapseError(400, "Not '%s' a string type" % (s,)) raise SynapseError(400, "Not '%s' a string type" % (s,))

View file

@ -20,6 +20,8 @@ from synapse.api.errors import SynapseError
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from twisted.internet import defer from twisted.internet import defer
from six import string_types
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -431,7 +433,7 @@ class GroupsServerHandler(object):
"long_description"): "long_description"):
if keyname in content: if keyname in content:
value = content[keyname] value = content[keyname]
if not isinstance(value, basestring): if not isinstance(value, string_types):
raise SynapseError(400, "%r value is not a string" % (keyname,)) raise SynapseError(400, "%r value is not a string" % (keyname,))
profile[keyname] = value profile[keyname] = value

View file

@ -30,6 +30,7 @@ class DeactivateAccountHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler() self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler() self._room_member_handler = hs.get_room_member_handler()
self.user_directory_handler = hs.get_user_directory_handler()
# Flag that indicates whether the process to part users from rooms is running # Flag that indicates whether the process to part users from rooms is running
self._user_parter_running = False self._user_parter_running = False
@ -61,10 +62,13 @@ class DeactivateAccountHandler(BaseHandler):
yield self.store.user_delete_threepids(user_id) yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None) yield self.store.user_set_password_hash(user_id, None)
# Add the user to a table of users penpding deactivation (ie. # Add the user to a table of users pending deactivation (ie.
# removal from all the rooms they're a member of) # removal from all the rooms they're a member of)
yield self.store.add_user_pending_deactivation(user_id) yield self.store.add_user_pending_deactivation(user_id)
# delete from user directory
yield self.user_directory_handler.handle_user_deactivated(user_id)
# Now start the process that goes through that list and # Now start the process that goes through that list and
# parts users from rooms (if it isn't already running) # parts users from rooms (if it isn't already running)
self._start_user_parting() self._start_user_parting()

View file

@ -26,6 +26,8 @@ from ._base import BaseHandler
import logging import logging
from six import itervalues, iteritems
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -318,7 +320,7 @@ class DeviceHandler(BaseHandler):
# The user may have left the room # The user may have left the room
# TODO: Check if they actually did or if we were just invited. # TODO: Check if they actually did or if we were just invited.
if room_id not in room_ids: if room_id not in room_ids:
for key, event_id in current_state_ids.iteritems(): for key, event_id in iteritems(current_state_ids):
etype, state_key = key etype, state_key = key
if etype != EventTypes.Member: if etype != EventTypes.Member:
continue continue
@ -338,7 +340,7 @@ class DeviceHandler(BaseHandler):
# special-case for an empty prev state: include all members # special-case for an empty prev state: include all members
# in the changed list # in the changed list
if not event_ids: if not event_ids:
for key, event_id in current_state_ids.iteritems(): for key, event_id in iteritems(current_state_ids):
etype, state_key = key etype, state_key = key
if etype != EventTypes.Member: if etype != EventTypes.Member:
continue continue
@ -354,10 +356,10 @@ class DeviceHandler(BaseHandler):
# Check if we've joined the room? If so we just blindly add all the users to # Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users. # the "possibly changed" users.
for state_dict in prev_state_ids.itervalues(): for state_dict in itervalues(prev_state_ids):
member_event = state_dict.get((EventTypes.Member, user_id), None) member_event = state_dict.get((EventTypes.Member, user_id), None)
if not member_event or member_event != current_member_id: if not member_event or member_event != current_member_id:
for key, event_id in current_state_ids.iteritems(): for key, event_id in iteritems(current_state_ids):
etype, state_key = key etype, state_key = key
if etype != EventTypes.Member: if etype != EventTypes.Member:
continue continue
@ -367,14 +369,14 @@ class DeviceHandler(BaseHandler):
# If there has been any change in membership, include them in the # If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below, # possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users. # and we're not toooo worried about spuriously adding users.
for key, event_id in current_state_ids.iteritems(): for key, event_id in iteritems(current_state_ids):
etype, state_key = key etype, state_key = key
if etype != EventTypes.Member: if etype != EventTypes.Member:
continue continue
# check if this member has changed since any of the extremities # check if this member has changed since any of the extremities
# at the stream_ordering, and add them to the list if so. # at the stream_ordering, and add them to the list if so.
for state_dict in prev_state_ids.itervalues(): for state_dict in itervalues(prev_state_ids):
prev_event_id = state_dict.get(key, None) prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id: if not prev_event_id or prev_event_id != event_id:
if state_key != user_id: if state_key != user_id:

View file

@ -19,6 +19,7 @@ import logging
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from twisted.internet import defer from twisted.internet import defer
from six import iteritems
from synapse.api.errors import ( from synapse.api.errors import (
SynapseError, CodeMessageException, FederationDeniedError, SynapseError, CodeMessageException, FederationDeniedError,
@ -92,7 +93,7 @@ class E2eKeysHandler(object):
remote_queries_not_in_cache = {} remote_queries_not_in_cache = {}
if remote_queries: if remote_queries:
query_list = [] query_list = []
for user_id, device_ids in remote_queries.iteritems(): for user_id, device_ids in iteritems(remote_queries):
if device_ids: if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids) query_list.extend((user_id, device_id) for device_id in device_ids)
else: else:
@ -103,9 +104,9 @@ class E2eKeysHandler(object):
query_list query_list
) )
) )
for user_id, devices in remote_results.iteritems(): for user_id, devices in iteritems(remote_results):
user_devices = results.setdefault(user_id, {}) user_devices = results.setdefault(user_id, {})
for device_id, device in devices.iteritems(): for device_id, device in iteritems(devices):
keys = device.get("keys", None) keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None) device_display_name = device.get("device_display_name", None)
if keys: if keys:
@ -250,9 +251,9 @@ class E2eKeysHandler(object):
"Claimed one-time-keys: %s", "Claimed one-time-keys: %s",
",".join(( ",".join((
"%s for %s:%s" % (key_id, user_id, device_id) "%s for %s:%s" % (key_id, user_id, device_id)
for user_id, user_keys in json_result.iteritems() for user_id, user_keys in iteritems(json_result)
for device_id, device_keys in user_keys.iteritems() for device_id, device_keys in iteritems(user_keys)
for key_id, _ in device_keys.iteritems() for key_id, _ in iteritems(device_keys)
)), )),
) )

View file

@ -24,6 +24,7 @@ from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json from signedjson.sign import verify_signed_json
import six import six
from six.moves import http_client from six.moves import http_client
from six import iteritems
from twisted.internet import defer from twisted.internet import defer
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
@ -479,18 +480,18 @@ class FederationHandler(BaseHandler):
# to get all state ids that we're interested in. # to get all state ids that we're interested in.
event_map = yield self.store.get_events([ event_map = yield self.store.get_events([
e_id e_id
for key_to_eid in event_to_state_ids.values() for key_to_eid in event_to_state_ids.itervalues()
for key, e_id in key_to_eid.items() for key, e_id in key_to_eid.iteritems()
if key[0] != EventTypes.Member or check_match(key[1]) if key[0] != EventTypes.Member or check_match(key[1])
]) ])
event_to_state = { event_to_state = {
e_id: { e_id: {
key: event_map[inner_e_id] key: event_map[inner_e_id]
for key, inner_e_id in key_to_eid.items() for key, inner_e_id in key_to_eid.iteritems()
if inner_e_id in event_map if inner_e_id in event_map
} }
for e_id, key_to_eid in event_to_state_ids.items() for e_id, key_to_eid in event_to_state_ids.iteritems()
} }
def redact_disallowed(event, state): def redact_disallowed(event, state):
@ -505,7 +506,7 @@ class FederationHandler(BaseHandler):
# membership states for the requesting server to determine # membership states for the requesting server to determine
# if the server is either in the room or has been invited # if the server is either in the room or has been invited
# into the room. # into the room.
for ev in state.values(): for ev in state.itervalues():
if ev.type != EventTypes.Member: if ev.type != EventTypes.Member:
continue continue
try: try:
@ -751,9 +752,19 @@ class FederationHandler(BaseHandler):
curr_state = yield self.state_handler.get_current_state(room_id) curr_state = yield self.state_handler.get_current_state(room_id)
def get_domains_from_state(state): def get_domains_from_state(state):
"""Get joined domains from state
Args:
state (dict[tuple, FrozenEvent]): State map from type/state
key to event.
Returns:
list[tuple[str, int]]: Returns a list of servers with the
lowest depth of their joins. Sorted by lowest depth first.
"""
joined_users = [ joined_users = [
(state_key, int(event.depth)) (state_key, int(event.depth))
for (e_type, state_key), event in state.items() for (e_type, state_key), event in state.iteritems()
if e_type == EventTypes.Member if e_type == EventTypes.Member
and event.membership == Membership.JOIN and event.membership == Membership.JOIN
] ]
@ -770,7 +781,7 @@ class FederationHandler(BaseHandler):
except Exception: except Exception:
pass pass
return sorted(joined_domains.items(), key=lambda d: d[1]) return sorted(joined_domains.iteritems(), key=lambda d: d[1])
curr_domains = get_domains_from_state(curr_state) curr_domains = get_domains_from_state(curr_state)
@ -787,7 +798,7 @@ class FederationHandler(BaseHandler):
yield self.backfill( yield self.backfill(
dom, room_id, dom, room_id,
limit=100, limit=100,
extremities=[e for e in extremities.keys()] extremities=extremities,
) )
# If this succeeded then we probably already have the # If this succeeded then we probably already have the
# appropriate stuff. # appropriate stuff.
@ -833,7 +844,7 @@ class FederationHandler(BaseHandler):
tried_domains = set(likely_domains) tried_domains = set(likely_domains)
tried_domains.add(self.server_name) tried_domains.add(self.server_name)
event_ids = list(extremities.keys()) event_ids = list(extremities.iterkeys())
logger.debug("calling resolve_state_groups in _maybe_backfill") logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = logcontext.preserve_fn( resolve = logcontext.preserve_fn(
@ -843,31 +854,34 @@ class FederationHandler(BaseHandler):
[resolve(room_id, [e]) for e in event_ids], [resolve(room_id, [e]) for e in event_ids],
consumeErrors=True, consumeErrors=True,
)) ))
# dict[str, dict[tuple, str]], a map from event_id to state map of
# event_ids.
states = dict(zip(event_ids, [s.state for s in states])) states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events( state_map = yield self.store.get_events(
[e_id for ids in states.values() for e_id in ids], [e_id for ids in states.itervalues() for e_id in ids.itervalues()],
get_prev_content=False get_prev_content=False
) )
states = { states = {
key: { key: {
k: state_map[e_id] k: state_map[e_id]
for k, e_id in state_dict.items() for k, e_id in state_dict.iteritems()
if e_id in state_map if e_id in state_map
} for key, state_dict in states.items() } for key, state_dict in states.iteritems()
} }
for e_id, _ in sorted_extremeties_tuple: for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id]) likely_domains = get_domains_from_state(states[e_id])
success = yield try_backfill([ success = yield try_backfill([
dom for dom in likely_domains dom for dom, _ in likely_domains
if dom not in tried_domains if dom not in tried_domains
]) ])
if success: if success:
defer.returnValue(True) defer.returnValue(True)
tried_domains.update(likely_domains) tried_domains.update(dom for dom, _ in likely_domains)
defer.returnValue(False) defer.returnValue(False)
@ -1375,7 +1389,7 @@ class FederationHandler(BaseHandler):
) )
if state_groups: if state_groups:
_, state = state_groups.items().pop() _, state = list(iteritems(state_groups)).pop()
results = { results = {
(e.type, e.state_key): e for e in state (e.type, e.state_key): e for e in state
} }
@ -2021,7 +2035,7 @@ class FederationHandler(BaseHandler):
this will not be included in the current_state in the context. this will not be included in the current_state in the context.
""" """
state_updates = { state_updates = {
k: a.event_id for k, a in auth_events.iteritems() k: a.event_id for k, a in iteritems(auth_events)
if k != event_key if k != event_key
} }
context.current_state_ids = dict(context.current_state_ids) context.current_state_ids = dict(context.current_state_ids)
@ -2031,7 +2045,7 @@ class FederationHandler(BaseHandler):
context.delta_ids.update(state_updates) context.delta_ids.update(state_updates)
context.prev_state_ids = dict(context.prev_state_ids) context.prev_state_ids = dict(context.prev_state_ids)
context.prev_state_ids.update({ context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.iteritems() k: a.event_id for k, a in iteritems(auth_events)
}) })
context.state_group = yield self.store.store_state_group( context.state_group = yield self.store.store_state_group(
event.event_id, event.event_id,
@ -2083,7 +2097,7 @@ class FederationHandler(BaseHandler):
def get_next(it, opt=None): def get_next(it, opt=None):
try: try:
return it.next() return next(it)
except Exception: except Exception:
return opt return opt

View file

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from six import iteritems
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
@ -449,7 +450,7 @@ class GroupsLocalHandler(object):
results = {} results = {}
failed_results = [] failed_results = []
for destination, dest_user_ids in destinations.iteritems(): for destination, dest_user_ids in iteritems(destinations):
try: try:
r = yield self.transport_client.bulk_get_publicised_groups( r = yield self.transport_client.bulk_get_publicised_groups(
destination, list(dest_user_ids), destination, list(dest_user_ids),

View file

@ -19,6 +19,7 @@ import sys
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
import six import six
from six import string_types, itervalues, iteritems
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.defer import succeed from twisted.internet.defer import succeed
from twisted.python.failure import Failure from twisted.python.failure import Failure
@ -402,7 +403,7 @@ class MessageHandler(BaseHandler):
"avatar_url": profile.avatar_url, "avatar_url": profile.avatar_url,
"display_name": profile.display_name, "display_name": profile.display_name,
} }
for user_id, profile in users_with_profile.iteritems() for user_id, profile in iteritems(users_with_profile)
}) })
@ -574,9 +575,14 @@ class EventCreationHandler(object):
if u["consent_version"] == self.config.user_consent_version: if u["consent_version"] == self.config.user_consent_version:
return return
consent_uri = self._consent_uri_builder.build_user_consent_uri(user_id) consent_uri = self._consent_uri_builder.build_user_consent_uri(
requester.user.localpart,
)
msg = self.config.block_events_without_consent_error % {
'consent_uri': consent_uri,
}
raise ConsentNotGivenError( raise ConsentNotGivenError(
msg=self.config.block_events_without_consent_error, msg=msg,
consent_uri=consent_uri, consent_uri=consent_uri,
) )
@ -662,7 +668,7 @@ class EventCreationHandler(object):
spam_error = self.spam_checker.check_event_for_spam(event) spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error: if spam_error:
if not isinstance(spam_error, basestring): if not isinstance(spam_error, string_types):
spam_error = "Spam is not permitted here" spam_error = "Spam is not permitted here"
raise SynapseError( raise SynapseError(
403, spam_error, Codes.FORBIDDEN 403, spam_error, Codes.FORBIDDEN
@ -876,7 +882,7 @@ class EventCreationHandler(object):
state_to_include_ids = [ state_to_include_ids = [
e_id e_id
for k, e_id in context.current_state_ids.iteritems() for k, e_id in iteritems(context.current_state_ids)
if k[0] in self.hs.config.room_invite_state_types if k[0] in self.hs.config.room_invite_state_types
or k == (EventTypes.Member, event.sender) or k == (EventTypes.Member, event.sender)
] ]
@ -890,7 +896,7 @@ class EventCreationHandler(object):
"content": e.content, "content": e.content,
"sender": e.sender, "sender": e.sender,
} }
for e in state_to_include.itervalues() for e in itervalues(state_to_include)
] ]
invitee = UserID.from_string(event.state_key) invitee = UserID.from_string(event.state_key)

View file

@ -25,6 +25,8 @@ The methods that define policy are:
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from contextlib import contextmanager from contextlib import contextmanager
from six import itervalues, iteritems
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
@ -42,7 +44,6 @@ import logging
from prometheus_client import Counter from prometheus_client import Counter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -535,7 +536,7 @@ class PresenceHandler(object):
prev_state.copy_and_replace( prev_state.copy_and_replace(
last_user_sync_ts=time_now_ms, last_user_sync_ts=time_now_ms,
) )
for prev_state in prev_states.itervalues() for prev_state in itervalues(prev_states)
]) ])
self.external_process_last_updated_ms.pop(process_id, None) self.external_process_last_updated_ms.pop(process_id, None)
@ -558,14 +559,14 @@ class PresenceHandler(object):
for user_id in user_ids for user_id in user_ids
} }
missing = [user_id for user_id, state in states.iteritems() if not state] missing = [user_id for user_id, state in iteritems(states) if not state]
if missing: if missing:
# There are things not in our in memory cache. Lets pull them out of # There are things not in our in memory cache. Lets pull them out of
# the database. # the database.
res = yield self.store.get_presence_for_users(missing) res = yield self.store.get_presence_for_users(missing)
states.update(res) states.update(res)
missing = [user_id for user_id, state in states.iteritems() if not state] missing = [user_id for user_id, state in iteritems(states) if not state]
if missing: if missing:
new = { new = {
user_id: UserPresenceState.default(user_id) user_id: UserPresenceState.default(user_id)
@ -1053,7 +1054,7 @@ class PresenceEventSource(object):
defer.returnValue((updates.values(), max_token)) defer.returnValue((updates.values(), max_token))
else: else:
defer.returnValue(([ defer.returnValue(([
s for s in updates.itervalues() s for s in itervalues(updates)
if s.state != PresenceState.OFFLINE if s.state != PresenceState.OFFLINE
], max_token)) ], max_token))
@ -1310,11 +1311,11 @@ def get_interested_remotes(store, states, state_handler):
# hosts in those rooms. # hosts in those rooms.
room_ids_to_states, users_to_states = yield get_interested_parties(store, states) room_ids_to_states, users_to_states = yield get_interested_parties(store, states)
for room_id, states in room_ids_to_states.iteritems(): for room_id, states in iteritems(room_ids_to_states):
hosts = yield state_handler.get_current_hosts_in_room(room_id) hosts = yield state_handler.get_current_hosts_in_room(room_id)
hosts_and_states.append((hosts, states)) hosts_and_states.append((hosts, states))
for user_id, states in users_to_states.iteritems(): for user_id, states in iteritems(users_to_states):
host = get_domain_from_id(user_id) host = get_domain_from_id(user_id)
hosts_and_states.append(([host], states)) hosts_and_states.append(([host], states))

View file

@ -298,15 +298,6 @@ class RoomMemberHandler(object):
is_blocked = yield self.store.is_room_blocked(room_id) is_blocked = yield self.store.is_room_blocked(room_id)
if is_blocked: if is_blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
else:
# we don't allow people to reject invites to, or leave, the
# server notice room.
is_blocked = yield self._is_server_notice_room(room_id)
if is_blocked:
raise SynapseError(
http_client.FORBIDDEN,
"You cannot leave this room",
)
if effective_membership_state == Membership.INVITE: if effective_membership_state == Membership.INVITE:
# block any attempts to invite the server notices mxid # block any attempts to invite the server notices mxid
@ -382,6 +373,20 @@ class RoomMemberHandler(object):
if same_sender and same_membership and same_content: if same_sender and same_membership and same_content:
defer.returnValue(old_state) defer.returnValue(old_state)
# we don't allow people to reject invites to the server notice
# room, but they can leave it once they are joined.
if (
old_membership == Membership.INVITE and
effective_membership_state == Membership.LEAVE
):
is_blocked = yield self._is_server_notice_room(room_id)
if is_blocked:
raise SynapseError(
http_client.FORBIDDEN,
"You cannot reject this invite",
errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM,
)
is_host_in_room = yield self._is_host_in_room(current_state_ids) is_host_in_room = yield self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN: if effective_membership_state == Membership.JOIN:

View file

@ -28,6 +28,8 @@ import collections
import logging import logging
import itertools import itertools
from six import itervalues, iteritems
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -275,7 +277,7 @@ class SyncHandler(object):
# result returned by the event source is poor form (it might cache # result returned by the event source is poor form (it might cache
# the object) # the object)
room_id = event["room_id"] room_id = event["room_id"]
event_copy = {k: v for (k, v) in event.iteritems() event_copy = {k: v for (k, v) in iteritems(event)
if k != "room_id"} if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy) ephemeral_by_room.setdefault(room_id, []).append(event_copy)
@ -294,7 +296,7 @@ class SyncHandler(object):
for event in receipts: for event in receipts:
room_id = event["room_id"] room_id = event["room_id"]
# exclude room id, as above # exclude room id, as above
event_copy = {k: v for (k, v) in event.iteritems() event_copy = {k: v for (k, v) in iteritems(event)
if k != "room_id"} if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy) ephemeral_by_room.setdefault(room_id, []).append(event_copy)
@ -325,7 +327,7 @@ class SyncHandler(object):
current_state_ids = frozenset() current_state_ids = frozenset()
if any(e.is_state() for e in recents): if any(e.is_state() for e in recents):
current_state_ids = yield self.state.get_current_state_ids(room_id) current_state_ids = yield self.state.get_current_state_ids(room_id)
current_state_ids = frozenset(current_state_ids.itervalues()) current_state_ids = frozenset(itervalues(current_state_ids))
recents = yield filter_events_for_client( recents = yield filter_events_for_client(
self.store, self.store,
@ -382,7 +384,7 @@ class SyncHandler(object):
current_state_ids = frozenset() current_state_ids = frozenset()
if any(e.is_state() for e in loaded_recents): if any(e.is_state() for e in loaded_recents):
current_state_ids = yield self.state.get_current_state_ids(room_id) current_state_ids = yield self.state.get_current_state_ids(room_id)
current_state_ids = frozenset(current_state_ids.itervalues()) current_state_ids = frozenset(itervalues(current_state_ids))
loaded_recents = yield filter_events_for_client( loaded_recents = yield filter_events_for_client(
self.store, self.store,
@ -984,7 +986,7 @@ class SyncHandler(object):
if since_token: if since_token:
for joined_sync in sync_result_builder.joined: for joined_sync in sync_result_builder.joined:
it = itertools.chain( it = itertools.chain(
joined_sync.timeline.events, joined_sync.state.itervalues() joined_sync.timeline.events, itervalues(joined_sync.state)
) )
for event in it: for event in it:
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
@ -1062,7 +1064,7 @@ class SyncHandler(object):
newly_left_rooms = [] newly_left_rooms = []
room_entries = [] room_entries = []
invited = [] invited = []
for room_id, events in mem_change_events_by_room_id.iteritems(): for room_id, events in iteritems(mem_change_events_by_room_id):
non_joins = [e for e in events if e.membership != Membership.JOIN] non_joins = [e for e in events if e.membership != Membership.JOIN]
has_join = len(non_joins) != len(events) has_join = len(non_joins) != len(events)

View file

@ -22,6 +22,7 @@ from synapse.util.metrics import Measure
from synapse.util.async import sleep from synapse.util.async import sleep
from synapse.types import get_localpart_from_id from synapse.types import get_localpart_from_id
from six import iteritems
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -122,6 +123,13 @@ class UserDirectoryHandler(object):
user_id, profile.display_name, profile.avatar_url, None, user_id, profile.display_name, profile.avatar_url, None,
) )
@defer.inlineCallbacks
def handle_user_deactivated(self, user_id):
"""Called when a user ID is deactivated
"""
yield self.store.remove_from_user_dir(user_id)
yield self.store.remove_from_user_in_public_room(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _unsafe_process(self): def _unsafe_process(self):
# If self.pos is None then means we haven't fetched it from DB # If self.pos is None then means we haven't fetched it from DB
@ -403,7 +411,7 @@ class UserDirectoryHandler(object):
if change: if change:
users_with_profile = yield self.state.get_current_user_in_room(room_id) users_with_profile = yield self.state.get_current_user_in_room(room_id)
for user_id, profile in users_with_profile.iteritems(): for user_id, profile in iteritems(users_with_profile):
yield self._handle_new_user(room_id, user_id, profile) yield self._handle_new_user(room_id, user_id, profile)
else: else:
users = yield self.store.get_users_in_public_due_to_room(room_id) users = yield self.store.get_users_in_public_due_to_room(room_id)

View file

@ -42,6 +42,8 @@ import random
import sys import sys
import urllib import urllib
from six.moves.urllib import parse as urlparse from six.moves.urllib import parse as urlparse
from six import string_types
from prometheus_client import Counter from prometheus_client import Counter
@ -549,7 +551,7 @@ class MatrixFederationHttpClient(object):
encoded_args = {} encoded_args = {}
for k, vs in args.items(): for k, vs in args.items():
if isinstance(vs, basestring): if isinstance(vs, string_types):
vs = [vs] vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs] encoded_args[k] = [v.encode("UTF-8") for v in vs]
@ -664,7 +666,7 @@ def check_content_type_is_json(headers):
RuntimeError if the RuntimeError if the
""" """
c_type = headers.getRawHeaders("Content-Type") c_type = headers.getRawHeaders(b"Content-Type")
if c_type is None: if c_type is None:
raise RuntimeError( raise RuntimeError(
"No Content-Type header" "No Content-Type header"
@ -681,7 +683,7 @@ def check_content_type_is_json(headers):
def encode_query_args(args): def encode_query_args(args):
encoded_args = {} encoded_args = {}
for k, vs in args.items(): for k, vs in args.items():
if isinstance(vs, basestring): if isinstance(vs, string_types):
vs = [vs] vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs] encoded_args[k] = [v.encode("UTF-8") for v in vs]

View file

@ -56,7 +56,7 @@ class SynapseRequest(Request):
def __repr__(self): def __repr__(self):
# We overwrite this so that we don't log ``access_token`` # We overwrite this so that we don't log ``access_token``
return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % ( return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
self.__class__.__name__, self.__class__.__name__,
id(self), id(self),
self.method, self.method,

View file

@ -29,6 +29,7 @@ from synapse.state import POWER_KEY
from collections import namedtuple from collections import namedtuple
from prometheus_client import Counter from prometheus_client import Counter
from six import itervalues, iteritems
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -122,7 +123,7 @@ class BulkPushRuleEvaluator(object):
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = yield self.store.get_events(auth_events_ids)
auth_events = { auth_events = {
(e.type, e.state_key): e for e in auth_events.itervalues() (e.type, e.state_key): e for e in itervalues(auth_events)
} }
sender_level = get_user_power_level(event.sender, auth_events) sender_level = get_user_power_level(event.sender, auth_events)
@ -156,7 +157,7 @@ class BulkPushRuleEvaluator(object):
condition_cache = {} condition_cache = {}
for uid, rules in rules_by_user.iteritems(): for uid, rules in iteritems(rules_by_user):
if event.sender == uid: if event.sender == uid:
continue continue
@ -402,7 +403,7 @@ class RulesForRoom(object):
# If the event is a join event then it will be in current state evnts # If the event is a join event then it will be in current state evnts
# map but not in the DB, so we have to explicitly insert it. # map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
for event_id in member_event_ids.itervalues(): for event_id in itervalues(member_event_ids):
if event_id == event.event_id: if event_id == event.event_id:
members[event_id] = (event.state_key, event.membership) members[event_id] = (event.state_key, event.membership)
@ -410,7 +411,7 @@ class RulesForRoom(object):
logger.debug("Found members %r: %r", self.room_id, members.values()) logger.debug("Found members %r: %r", self.room_id, members.values())
interested_in_user_ids = set( interested_in_user_ids = set(
user_id for user_id, membership in members.itervalues() user_id for user_id, membership in itervalues(members)
if membership == Membership.JOIN if membership == Membership.JOIN
) )
@ -422,7 +423,7 @@ class RulesForRoom(object):
) )
user_ids = set( user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher
) )
logger.debug("With pushers: %r", user_ids) logger.debug("With pushers: %r", user_ids)
@ -443,7 +444,7 @@ class RulesForRoom(object):
) )
ret_rules_by_user.update( ret_rules_by_user.update(
item for item in rules_by_user.iteritems() if item[0] is not None item for item in iteritems(rules_by_user) if item[0] is not None
) )
self.update_cache(sequence, members, ret_rules_by_user, state_group) self.update_cache(sequence, members, ret_rules_by_user, state_group)

View file

@ -21,6 +21,8 @@ from synapse.types import UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from six import string_types
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -238,7 +240,7 @@ def _flatten_dict(d, prefix=[], result=None):
if result is None: if result is None:
result = {} result = {}
for key, value in d.items(): for key, value in d.items():
if isinstance(value, basestring): if isinstance(value, string_types):
result[".".join(prefix + [key])] = value.lower() result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"): elif hasattr(value, "items"):
_flatten_dict(value, prefix=(prefix + [key]), result=result) _flatten_dict(value, prefix=(prefix + [key]), result=result)

View file

@ -67,14 +67,14 @@ from prometheus_client import Counter
from collections import defaultdict from collections import defaultdict
from six import iterkeys, iteritems
import logging import logging
import struct import struct
import fcntl import fcntl
connection_close_counter = Counter( connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"], "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"])
)
# A list of all connected protocols. This allows us to send metrics about the # A list of all connected protocols. This allows us to send metrics about the
# connections. # connections.
@ -389,7 +389,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
if stream_name == "ALL": if stream_name == "ALL":
# Subscribe to all streams we're publishing to. # Subscribe to all streams we're publishing to.
for stream in self.streamer.streams_by_name.iterkeys(): for stream in iterkeys(self.streamer.streams_by_name):
self.subscribe_to_stream(stream, token) self.subscribe_to_stream(stream, token)
else: else:
self.subscribe_to_stream(stream_name, token) self.subscribe_to_stream(stream_name, token)
@ -495,7 +495,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
BaseReplicationStreamProtocol.connectionMade(self) BaseReplicationStreamProtocol.connectionMade(self)
# Once we've connected subscribe to the necessary streams # Once we've connected subscribe to the necessary streams
for stream_name, token in self.handler.get_streams_to_replicate().iteritems(): for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
self.replicate(stream_name, token) self.replicate(stream_name, token)
# Tell the server if we have any users currently syncing (should only # Tell the server if we have any users currently syncing (should only
@ -622,7 +622,7 @@ tcp_inbound_commands = LaterGauge(
lambda: { lambda: {
(k[0], p.name, p.conn_id): count (k[0], p.name, p.conn_id): count
for p in connected_connections for p in connected_connections
for k, count in p.inbound_commands_counter.items() for k, count in iteritems(p.inbound_commands_counter.counts)
}) })
tcp_outbound_commands = LaterGauge( tcp_outbound_commands = LaterGauge(
@ -630,7 +630,7 @@ tcp_outbound_commands = LaterGauge(
lambda: { lambda: {
(k[0], p.name, p.conn_id): count (k[0], p.name, p.conn_id): count
for p in connected_connections for p in connected_connections
for k, count in p.outbound_commands_counter.items() for k, count in iteritems(p.outbound_commands_counter.counts)
}) })
# number of updates received for each RDATA stream # number of updates received for each RDATA stream

View file

@ -27,6 +27,7 @@ from synapse.metrics import LaterGauge
import logging import logging
from prometheus_client import Counter from prometheus_client import Counter
from six import itervalues
stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates", stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
"", ["stream_name"]) "", ["stream_name"])
@ -81,7 +82,7 @@ class ReplicationStreamer(object):
# We only support federation stream if federation sending hase been # We only support federation stream if federation sending hase been
# disabled on the master. # disabled on the master.
self.streams = [ self.streams = [
stream(hs) for stream in STREAMS_MAP.itervalues() stream(hs) for stream in itervalues(STREAMS_MAP)
if stream != FederationStream or not hs.config.send_federation if stream != FederationStream or not hs.config.send_federation
] ]

View file

@ -23,6 +23,8 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
from six import string_types
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -71,7 +73,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
if "status_msg" in content: if "status_msg" in content:
state["status_msg"] = content.pop("status_msg") state["status_msg"] = content.pop("status_msg")
if not isinstance(state["status_msg"], basestring): if not isinstance(state["status_msg"], string_types):
raise SynapseError(400, "status_msg must be a string.") raise SynapseError(400, "status_msg must be a string.")
if content: if content:
@ -129,7 +131,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if "invite" in content: if "invite" in content:
for u in content["invite"]: for u in content["invite"]:
if not isinstance(u, basestring): if not isinstance(u, string_types):
raise SynapseError(400, "Bad invite value.") raise SynapseError(400, "Bad invite value.")
if len(u) == 0: if len(u) == 0:
continue continue
@ -140,7 +142,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if "drop" in content: if "drop" in content:
for u in content["drop"]: for u in content["drop"]:
if not isinstance(u, basestring): if not isinstance(u, string_types):
raise SynapseError(400, "Bad drop value.") raise SynapseError(400, "Bad drop value.")
if len(u) == 0: if len(u) == 0:
continue continue

View file

@ -48,6 +48,7 @@ import shutil
import cgi import cgi
import logging import logging
from six.moves.urllib import parse as urlparse from six.moves.urllib import parse as urlparse
from six import iteritems
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -603,7 +604,7 @@ class MediaRepository(object):
thumbnails[(t_width, t_height, r_type)] = r_method thumbnails[(t_width, t_height, r_type)] = r_method
# Now we generate the thumbnails for each dimension, store it # Now we generate the thumbnails for each dimension, store it
for (t_width, t_height, t_type), t_method in thumbnails.iteritems(): for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
# Generate the thumbnail # Generate the thumbnail
if t_method == "crop": if t_method == "crop":
t_byte_source = yield make_deferred_yieldable(threads.deferToThread( t_byte_source = yield make_deferred_yieldable(threads.deferToThread(

View file

@ -24,7 +24,9 @@ import shutil
import sys import sys
import traceback import traceback
import simplejson as json import simplejson as json
import urlparse
from six.moves import urllib_parse as urlparse
from six import string_types
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer from twisted.internet import defer
@ -590,8 +592,8 @@ def _iterate_over_text(tree, *tags_to_ignore):
# to be returned. # to be returned.
elements = iter([tree]) elements = iter([tree])
while True: while True:
el = elements.next() el = next(elements)
if isinstance(el, basestring): if isinstance(el, string_types):
yield el yield el
elif el is not None and el.tag not in tags_to_ignore: elif el is not None and el.tag not in tags_to_ignore:
# el.text is the text before the first child, so we can immediately # el.text is the text before the first child, so we can immediately

View file

@ -14,10 +14,13 @@
# limitations under the License. # limitations under the License.
import logging import logging
from six import (iteritems, string_types)
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.urls import ConsentURIBuilder
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.types import get_localpart_from_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,6 +42,7 @@ class ConsentServerNotices(object):
self._current_consent_version = hs.config.user_consent_version self._current_consent_version = hs.config.user_consent_version
self._server_notice_content = hs.config.user_consent_server_notice_content self._server_notice_content = hs.config.user_consent_server_notice_content
self._send_to_guests = hs.config.user_consent_server_notice_to_guests
if self._server_notice_content is not None: if self._server_notice_content is not None:
if not self._server_notices_manager.is_enabled(): if not self._server_notices_manager.is_enabled():
@ -52,6 +56,8 @@ class ConsentServerNotices(object):
"key.", "key.",
) )
self._consent_uri_builder = ConsentURIBuilder(hs.config)
@defer.inlineCallbacks @defer.inlineCallbacks
def maybe_send_server_notice_to_user(self, user_id): def maybe_send_server_notice_to_user(self, user_id):
"""Check if we need to send a notice to this user, and does so if so """Check if we need to send a notice to this user, and does so if so
@ -73,6 +79,10 @@ class ConsentServerNotices(object):
try: try:
u = yield self._store.get_user_by_id(user_id) u = yield self._store.get_user_by_id(user_id)
if u["is_guest"] and not self._send_to_guests:
# don't send to guests
return
if u["consent_version"] == self._current_consent_version: if u["consent_version"] == self._current_consent_version:
# user has already consented # user has already consented
return return
@ -81,10 +91,18 @@ class ConsentServerNotices(object):
# we've already sent a notice to the user # we've already sent a notice to the user
return return
# need to send a message # need to send a message.
try: try:
consent_uri = self._consent_uri_builder.build_user_consent_uri(
get_localpart_from_id(user_id),
)
content = copy_with_str_subst(
self._server_notice_content, {
'consent_uri': consent_uri,
},
)
yield self._server_notices_manager.send_notice( yield self._server_notices_manager.send_notice(
user_id, self._server_notice_content, user_id, content,
) )
yield self._store.user_set_consent_server_notice_sent( yield self._store.user_set_consent_server_notice_sent(
user_id, self._current_consent_version, user_id, self._current_consent_version,
@ -93,3 +111,27 @@ class ConsentServerNotices(object):
logger.error("Error sending server notice about user consent: %s", e) logger.error("Error sending server notice about user consent: %s", e)
finally: finally:
self._users_in_progress.remove(user_id) self._users_in_progress.remove(user_id)
def copy_with_str_subst(x, substitutions):
"""Deep-copy a structure, carrying out string substitions on any strings
Args:
x (object): structure to be copied
substitutions (object): substitutions to be made - passed into the
string '%' operator
Returns:
copy of x
"""
if isinstance(x, string_types):
return x % substitutions
if isinstance(x, dict):
return {
k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x)
}
if isinstance(x, (list, tuple)):
return [copy_with_str_subst(y) for y in x]
# assume it's uninterested and can be shallow-copied.
return x

View file

@ -35,6 +35,7 @@ class ServerNoticesManager(object):
self._config = hs.config self._config = hs.config
self._room_creation_handler = hs.get_room_creation_handler() self._room_creation_handler = hs.get_room_creation_handler()
self._event_creation_handler = hs.get_event_creation_handler() self._event_creation_handler = hs.get_event_creation_handler()
self._is_mine_id = hs.is_mine_id
def is_enabled(self): def is_enabled(self):
"""Checks if server notices are enabled on this server. """Checks if server notices are enabled on this server.
@ -55,7 +56,7 @@ class ServerNoticesManager(object):
event_content (dict): content of event to send event_content (dict): content of event to send
Returns: Returns:
Deferrred[None] Deferred[None]
""" """
room_id = yield self.get_notice_room_for_user(user_id) room_id = yield self.get_notice_room_for_user(user_id)
@ -89,6 +90,9 @@ class ServerNoticesManager(object):
if not self.is_enabled(): if not self.is_enabled():
raise Exception("Server notices not enabled") raise Exception("Server notices not enabled")
assert self._is_mine_id(user_id), \
"Cannot send server notices to remote users"
rooms = yield self._store.get_rooms_for_user_where_membership_is( rooms = yield self._store.get_rooms_for_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN], user_id, [Membership.INVITE, Membership.JOIN],
) )
@ -109,6 +113,19 @@ class ServerNoticesManager(object):
# apparently no existing notice room: create a new one # apparently no existing notice room: create a new one
logger.info("Creating server notices room for %s", user_id) logger.info("Creating server notices room for %s", user_id)
# see if we want to override the profile info for the server user.
# note that if we want to override either the display name or the
# avatar, we have to use both.
join_profile = None
if (
self._config.server_notices_mxid_display_name is not None or
self._config.server_notices_mxid_avatar_url is not None
):
join_profile = {
"displayname": self._config.server_notices_mxid_display_name,
"avatar_url": self._config.server_notices_mxid_avatar_url,
}
requester = create_requester(system_mxid) requester = create_requester(system_mxid)
info = yield self._room_creation_handler.create_room( info = yield self._room_creation_handler.create_room(
requester, requester,
@ -121,9 +138,7 @@ class ServerNoticesManager(object):
"invite": (user_id,) "invite": (user_id,)
}, },
ratelimit=False, ratelimit=False,
creator_join_profile={ creator_join_profile=join_profile,
"displayname": self._config.server_notices_mxid_display_name,
},
) )
room_id = info['room_id'] room_id = info['room_id']

View file

@ -32,7 +32,7 @@ class WorkerServerNoticesSender(object):
Returns: Returns:
Deferred Deferred
""" """
return defer.succeed() return defer.succeed(None)
def on_user_ip(self, user_id): def on_user_ip(self, user_id):
"""Called on the master when a worker process saw a client request. """Called on the master when a worker process saw a client request.

View file

@ -32,6 +32,8 @@ from frozendict import frozendict
import logging import logging
import hashlib import hashlib
from six import iteritems, itervalues
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -132,7 +134,7 @@ class StateHandler(object):
state_map = yield self.store.get_events(state.values(), get_prev_content=False) state_map = yield self.store.get_events(state.values(), get_prev_content=False)
state = { state = {
key: state_map[e_id] for key, e_id in state.iteritems() if e_id in state_map key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map
} }
defer.returnValue(state) defer.returnValue(state)
@ -338,7 +340,7 @@ class StateHandler(object):
) )
if len(state_groups_ids) == 1: if len(state_groups_ids) == 1:
name, state_list = state_groups_ids.items().pop() name, state_list = list(state_groups_ids.items()).pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name) prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@ -378,7 +380,7 @@ class StateHandler(object):
new_state = resolve_events_with_state_map(state_set_ids, state_map) new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = { new_state = {
key: state_map[ev_id] for key, ev_id in new_state.iteritems() key: state_map[ev_id] for key, ev_id in iteritems(new_state)
} }
return new_state return new_state
@ -458,15 +460,15 @@ class StateResolutionHandler(object):
# build a map from state key to the event_ids which set that state. # build a map from state key to the event_ids which set that state.
# dict[(str, str), set[str]) # dict[(str, str), set[str])
state = {} state = {}
for st in state_groups_ids.itervalues(): for st in itervalues(state_groups_ids):
for key, e_id in st.iteritems(): for key, e_id in iteritems(st):
state.setdefault(key, set()).add(e_id) state.setdefault(key, set()).add(e_id)
# build a map from state key to the event_ids which set that state, # build a map from state key to the event_ids which set that state,
# including only those where there are state keys in conflict. # including only those where there are state keys in conflict.
conflicted_state = { conflicted_state = {
k: list(v) k: list(v)
for k, v in state.iteritems() for k, v in iteritems(state)
if len(v) > 1 if len(v) > 1
} }
@ -474,13 +476,13 @@ class StateResolutionHandler(object):
logger.info("Resolving conflicted state for %r", room_id) logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory( new_state = yield resolve_events_with_factory(
state_groups_ids.values(), list(state_groups_ids.values()),
event_map=event_map, event_map=event_map,
state_map_factory=state_map_factory, state_map_factory=state_map_factory,
) )
else: else:
new_state = { new_state = {
key: e_ids.pop() for key, e_ids in state.iteritems() key: e_ids.pop() for key, e_ids in iteritems(state)
} }
with Measure(self.clock, "state.create_group_ids"): with Measure(self.clock, "state.create_group_ids"):
@ -489,8 +491,8 @@ class StateResolutionHandler(object):
# which will be used as a cache key for future resolutions, but # which will be used as a cache key for future resolutions, but
# not get persisted. # not get persisted.
state_group = None state_group = None
new_state_event_ids = frozenset(new_state.itervalues()) new_state_event_ids = frozenset(itervalues(new_state))
for sg, events in state_groups_ids.iteritems(): for sg, events in iteritems(state_groups_ids):
if new_state_event_ids == frozenset(e_id for e_id in events): if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg state_group = sg
break break
@ -501,11 +503,11 @@ class StateResolutionHandler(object):
prev_group = None prev_group = None
delta_ids = None delta_ids = None
for old_group, old_ids in state_groups_ids.iteritems(): for old_group, old_ids in iteritems(state_groups_ids):
if not set(new_state) - set(old_ids): if not set(new_state) - set(old_ids):
n_delta_ids = { n_delta_ids = {
k: v k: v
for k, v in new_state.iteritems() for k, v in iteritems(new_state)
if old_ids.get(k) != v if old_ids.get(k) != v
} }
if not delta_ids or len(n_delta_ids) < len(delta_ids): if not delta_ids or len(n_delta_ids) < len(delta_ids):
@ -527,7 +529,7 @@ class StateResolutionHandler(object):
def _ordered_events(events): def _ordered_events(events):
def key_func(e): def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
return sorted(events, key=key_func) return sorted(events, key=key_func)
@ -584,7 +586,7 @@ def _seperate(state_sets):
conflicted_state = {} conflicted_state = {}
for state_set in state_sets[1:]: for state_set in state_sets[1:]:
for key, value in state_set.iteritems(): for key, value in iteritems(state_set):
# Check if there is an unconflicted entry for the state key. # Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key) unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None: if unconflicted_value is None:
@ -640,7 +642,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
needed_events = set( needed_events = set(
event_id event_id
for event_ids in conflicted_state.itervalues() for event_ids in itervalues(conflicted_state)
for event_id in event_ids for event_id in event_ids
) )
if event_map is not None: if event_map is not None:
@ -662,7 +664,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
unconflicted_state, conflicted_state, state_map unconflicted_state, conflicted_state, state_map
) )
new_needed_events = set(auth_events.itervalues()) new_needed_events = set(itervalues(auth_events))
new_needed_events -= needed_events new_needed_events -= needed_events
if event_map is not None: if event_map is not None:
new_needed_events -= set(event_map.iterkeys()) new_needed_events -= set(event_map.iterkeys())
@ -679,7 +681,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {} auth_events = {}
for event_ids in conflicted_state.itervalues(): for event_ids in itervalues(conflicted_state):
for event_id in event_ids: for event_id in event_ids:
if event_id in state_map: if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id]) keys = event_auth.auth_types_for_event(state_map[event_id])
@ -694,7 +696,7 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids, def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
state_map): state_map):
conflicted_state = {} conflicted_state = {}
for key, event_ids in conflicted_state_ds.iteritems(): for key, event_ids in iteritems(conflicted_state_ds):
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1: if len(events) > 1:
conflicted_state[key] = events conflicted_state[key] = events
@ -703,7 +705,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
auth_events = { auth_events = {
key: state_map[ev_id] key: state_map[ev_id]
for key, ev_id in auth_event_ids.iteritems() for key, ev_id in iteritems(auth_event_ids)
if ev_id in state_map if ev_id in state_map
} }
@ -716,7 +718,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
raise raise
new_state = unconflicted_state_ids new_state = unconflicted_state_ids
for key, event in resolved_state.iteritems(): for key, event in iteritems(resolved_state):
new_state[key] = event.event_id new_state[key] = event.event_id
return new_state return new_state
@ -741,7 +743,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in conflicted_state.iteritems(): for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.JoinRules: if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events) logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events( resolved_state[key] = _resolve_auth_events(
@ -751,7 +753,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in conflicted_state.iteritems(): for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.Member: if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events) logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events( resolved_state[key] = _resolve_auth_events(
@ -761,7 +763,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in conflicted_state.iteritems(): for key, events in iteritems(conflicted_state):
if key not in resolved_state: if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events) logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events( resolved_state[key] = _resolve_normal_events(

View file

@ -27,9 +27,17 @@ import sys
import time import time
import threading import threading
from six import itervalues, iterkeys, iteritems
from six.moves import intern, range
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try:
MAX_TXN_ID = sys.maxint - 1
except AttributeError:
# python 3 does not have a maximum int value
MAX_TXN_ID = 2**63 - 1
sql_logger = logging.getLogger("synapse.storage.SQL") sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn") transaction_logger = logging.getLogger("synapse.storage.txn")
perf_logger = logging.getLogger("synapse.storage.TIME") perf_logger = logging.getLogger("synapse.storage.TIME")
@ -134,7 +142,7 @@ class PerformanceCounters(object):
def interval(self, interval_duration, limit=3): def interval(self, interval_duration, limit=3):
counters = [] counters = []
for name, (count, cum_time) in self.current_counters.iteritems(): for name, (count, cum_time) in iteritems(self.current_counters):
prev_count, prev_time = self.previous_counters.get(name, (0, 0)) prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append(( counters.append((
(cum_time - prev_time) / interval_duration, (cum_time - prev_time) / interval_duration,
@ -219,7 +227,7 @@ class SQLBaseStore(object):
# We don't really need these to be unique, so lets stop it from # We don't really need these to be unique, so lets stop it from
# growing really large. # growing really large.
self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1) self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
name = "%s-%x" % (desc, txn_id, ) name = "%s-%x" % (desc, txn_id, )
@ -540,7 +548,7 @@ class SQLBaseStore(object):
", ".join("%s = ?" % (k,) for k in values), ", ".join("%s = ?" % (k,) for k in values),
" AND ".join("%s = ?" % (k,) for k in keyvalues) " AND ".join("%s = ?" % (k,) for k in keyvalues)
) )
sqlargs = values.values() + keyvalues.values() sqlargs = list(values.values()) + list(keyvalues.values())
txn.execute(sql, sqlargs) txn.execute(sql, sqlargs)
if txn.rowcount > 0: if txn.rowcount > 0:
@ -558,7 +566,7 @@ class SQLBaseStore(object):
", ".join(k for k in allvalues), ", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues) ", ".join("?" for _ in allvalues)
) )
txn.execute(sql, allvalues.values()) txn.execute(sql, list(allvalues.values()))
# successfully inserted # successfully inserted
return True return True
@ -626,8 +634,8 @@ class SQLBaseStore(object):
} }
if keyvalues: if keyvalues:
sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
txn.execute(sql, keyvalues.values()) txn.execute(sql, list(keyvalues.values()))
else: else:
txn.execute(sql) txn.execute(sql)
@ -691,7 +699,7 @@ class SQLBaseStore(object):
table, table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues) " AND ".join("%s = ?" % (k, ) for k in keyvalues)
) )
txn.execute(sql, keyvalues.values()) txn.execute(sql, list(keyvalues.values()))
else: else:
sql = "SELECT %s FROM %s" % ( sql = "SELECT %s FROM %s" % (
", ".join(retcols), ", ".join(retcols),
@ -722,9 +730,12 @@ class SQLBaseStore(object):
if not iterable: if not iterable:
defer.returnValue(results) defer.returnValue(results)
# iterables can not be sliced, so convert it to a list first
it_list = list(iterable)
chunks = [ chunks = [
iterable[i:i + batch_size] it_list[i:i + batch_size]
for i in xrange(0, len(iterable), batch_size) for i in range(0, len(it_list), batch_size)
] ]
for chunk in chunks: for chunk in chunks:
rows = yield self.runInteraction( rows = yield self.runInteraction(
@ -764,7 +775,7 @@ class SQLBaseStore(object):
) )
values.extend(iterable) values.extend(iterable)
for key, value in keyvalues.iteritems(): for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,)) clauses.append("%s = ?" % (key,))
values.append(value) values.append(value)
@ -787,7 +798,7 @@ class SQLBaseStore(object):
@staticmethod @staticmethod
def _simple_update_txn(txn, table, keyvalues, updatevalues): def _simple_update_txn(txn, table, keyvalues, updatevalues):
if keyvalues: if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
else: else:
where = "" where = ""
@ -799,7 +810,7 @@ class SQLBaseStore(object):
txn.execute( txn.execute(
update_sql, update_sql,
updatevalues.values() + keyvalues.values() list(updatevalues.values()) + list(keyvalues.values())
) )
return txn.rowcount return txn.rowcount
@ -847,7 +858,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues) " AND ".join("%s = ?" % (k,) for k in keyvalues)
) )
txn.execute(select_sql, keyvalues.values()) txn.execute(select_sql, list(keyvalues.values()))
row = txn.fetchone() row = txn.fetchone()
if not row: if not row:
@ -885,7 +896,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k, ) for k in keyvalues) " AND ".join("%s = ?" % (k, ) for k in keyvalues)
) )
txn.execute(sql, keyvalues.values()) txn.execute(sql, list(keyvalues.values()))
if txn.rowcount == 0: if txn.rowcount == 0:
raise StoreError(404, "No row found") raise StoreError(404, "No row found")
if txn.rowcount > 1: if txn.rowcount > 1:
@ -903,7 +914,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k, ) for k in keyvalues) " AND ".join("%s = ?" % (k, ) for k in keyvalues)
) )
return txn.execute(sql, keyvalues.values()) return txn.execute(sql, list(keyvalues.values()))
def _simple_delete_many(self, table, column, iterable, keyvalues, desc): def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
return self.runInteraction( return self.runInteraction(
@ -935,7 +946,7 @@ class SQLBaseStore(object):
) )
values.extend(iterable) values.extend(iterable)
for key, value in keyvalues.iteritems(): for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,)) clauses.append("%s = ?" % (key,))
values.append(value) values.append(value)
@ -975,7 +986,7 @@ class SQLBaseStore(object):
txn.close() txn.close()
if cache: if cache:
min_val = min(cache.itervalues()) min_val = min(itervalues(cache))
else: else:
min_val = max_value min_val = max_value
@ -1090,7 +1101,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues), " AND ".join("%s = ?" % (k,) for k in keyvalues),
" ? ASC LIMIT ? OFFSET ?" " ? ASC LIMIT ? OFFSET ?"
) )
txn.execute(sql, keyvalues.values() + pagevalues) txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
else: else:
sql = "SELECT %s FROM %s ORDER BY %s" % ( sql = "SELECT %s FROM %s ORDER BY %s" % (
", ".join(retcols), ", ".join(retcols),

View file

@ -22,6 +22,8 @@ from . import background_updates
from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches import CACHE_SIZE_FACTOR
from six import iteritems
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -99,7 +101,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
def _update_client_ips_batch_txn(self, txn, to_update): def _update_client_ips_batch_txn(self, txn, to_update):
self.database_engine.lock_table(txn, "user_ips") self.database_engine.lock_table(txn, "user_ips")
for entry in to_update.iteritems(): for entry in iteritems(to_update):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
self._simple_upsert_txn( self._simple_upsert_txn(
@ -231,5 +233,5 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"user_agent": user_agent, "user_agent": user_agent,
"last_seen": last_seen, "last_seen": last_seen,
} }
for (access_token, ip), (user_agent, last_seen) in results.iteritems() for (access_token, ip), (user_agent, last_seen) in iteritems(results)
)) ))

View file

@ -21,6 +21,7 @@ from synapse.api.errors import StoreError
from ._base import SQLBaseStore, Cache from ._base import SQLBaseStore, Cache
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
from six import itervalues, iteritems
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -360,7 +361,7 @@ class DeviceStore(SQLBaseStore):
return (now_stream_id, []) return (now_stream_id, [])
if len(query_map) >= 20: if len(query_map) >= 20:
now_stream_id = max(stream_id for stream_id in query_map.itervalues()) now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn( devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True txn, query_map.keys(), include_all_devices=True
@ -373,13 +374,13 @@ class DeviceStore(SQLBaseStore):
""" """
results = [] results = []
for user_id, user_devices in devices.iteritems(): for user_id, user_devices in iteritems(devices):
# The prev_id for the first row is always the last row before # The prev_id for the first row is always the last row before
# `from_stream_id` # `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall() rows = txn.fetchall()
prev_id = rows[0][0] prev_id = rows[0][0]
for device_id, device in user_devices.iteritems(): for device_id, device in iteritems(user_devices):
stream_id = query_map[(user_id, device_id)] stream_id = query_map[(user_id, device_id)]
result = { result = {
"user_id": user_id, "user_id": user_id,
@ -483,7 +484,7 @@ class DeviceStore(SQLBaseStore):
if devices: if devices:
user_devices = devices[user_id] user_devices = devices[user_id]
results = [] results = []
for device_id, device in user_devices.iteritems(): for device_id, device in iteritems(user_devices):
result = { result = {
"device_id": device_id, "device_id": device_id,
} }

View file

@ -21,6 +21,8 @@ import simplejson as json
from ._base import SQLBaseStore from ._base import SQLBaseStore
from six import iteritems
class EndToEndKeyStore(SQLBaseStore): class EndToEndKeyStore(SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
@ -81,8 +83,8 @@ class EndToEndKeyStore(SQLBaseStore):
query_list, include_all_devices, query_list, include_all_devices,
) )
for user_id, device_keys in results.iteritems(): for user_id, device_keys in iteritems(results):
for device_id, device_info in device_keys.iteritems(): for device_id, device_info in iteritems(device_keys):
device_info["keys"] = json.loads(device_info.pop("key_json")) device_info["keys"] = json.loads(device_info.pop("key_json"))
defer.returnValue(results) defer.returnValue(results)

View file

@ -22,6 +22,8 @@ from synapse.util.caches.descriptors import cachedInlineCallbacks
import logging import logging
import simplejson as json import simplejson as json
from six import iteritems
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -420,7 +422,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.executemany(sql, ( txn.executemany(sql, (
_gen_entry(user_id, actions) _gen_entry(user_id, actions)
for user_id, actions in user_id_actions.iteritems() for user_id, actions in iteritems(user_id_actions)
)) ))
return self.runInteraction( return self.runInteraction(

View file

@ -337,7 +337,7 @@ class EventsWorkerStore(SQLBaseStore):
def _fetch_event_rows(self, txn, events): def _fetch_event_rows(self, txn, events):
rows = [] rows = []
N = 200 N = 200
for i in range(1 + len(events) / N): for i in range(1 + len(events) // N):
evs = events[i * N:(i + 1) * N] evs = events[i * N:(i + 1) * N]
if not evs: if not evs:
break break

View file

@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter", desc="get_user_filter",
) )
defer.returnValue(json.loads(str(def_json).decode("utf-8"))) defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
def add_user_filter(self, user_localpart, user_filter): def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter) def_json = encode_canonical_json(user_filter)

View file

@ -92,7 +92,7 @@ class KeyStore(SQLBaseStore):
if verify_key_bytes: if verify_key_bytes:
defer.returnValue(decode_verify_key_bytes( defer.returnValue(decode_verify_key_bytes(
key_id, str(verify_key_bytes) key_id, bytes(verify_key_bytes)
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -30,6 +30,8 @@ from synapse.types import get_domain_from_id
import logging import logging
import simplejson as json import simplejson as json
from six import itervalues, iteritems
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -272,7 +274,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
users_in_room = {} users_in_room = {}
member_event_ids = [ member_event_ids = [
e_id e_id
for key, e_id in current_state_ids.iteritems() for key, e_id in iteritems(current_state_ids)
if key[0] == EventTypes.Member if key[0] == EventTypes.Member
] ]
@ -289,7 +291,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
users_in_room = dict(prev_res) users_in_room = dict(prev_res)
member_event_ids = [ member_event_ids = [
e_id e_id
for key, e_id in context.delta_ids.iteritems() for key, e_id in iteritems(context.delta_ids)
if key[0] == EventTypes.Member if key[0] == EventTypes.Member
] ]
for etype, state_key in context.delta_ids: for etype, state_key in context.delta_ids:
@ -741,7 +743,7 @@ class _JoinedHostsCache(object):
if state_entry.state_group == self.state_group: if state_entry.state_group == self.state_group:
pass pass
elif state_entry.prev_group == self.state_group: elif state_entry.prev_group == self.state_group:
for (typ, state_key), event_id in state_entry.delta_ids.iteritems(): for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
if typ != EventTypes.Member: if typ != EventTypes.Member:
continue continue
@ -771,7 +773,7 @@ class _JoinedHostsCache(object):
self.state_group = state_entry.state_group self.state_group = state_entry.state_group
else: else:
self.state_group = object() self.state_group = object()
self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues()) self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
defer.returnValue(frozenset(self.hosts_to_joined_users)) defer.returnValue(frozenset(self.hosts_to_joined_users))
def __len__(self): def __len__(self):

View file

@ -684,8 +684,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results to only those before results to only those before
direction(char): Either 'b' or 'f' to indicate whether we are direction(char): Either 'b' or 'f' to indicate whether we are
paginating forwards or backwards from `from_key`. paginating forwards or backwards from `from_key`.
limit (int): The maximum number of events to return. Zero or less limit (int): The maximum number of events to return.
means no limit.
event_filter (Filter|None): If provided filters the events to event_filter (Filter|None): If provided filters the events to
those that match the filter. those that match the filter.
@ -694,6 +693,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
as a list of _EventDictReturn and a token that points to the end as a list of _EventDictReturn and a token that points to the end
of the result set. of the result set.
""" """
assert int(limit) >= 0
# Tokens really represent positions between elements, but we use # Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence # the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities. # we have a bit of asymmetry when it comes to equalities.
@ -723,22 +725,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
bounds += " AND " + filter_clause bounds += " AND " + filter_clause
args.extend(filter_args) args.extend(filter_args)
if int(limit) > 0:
args.append(int(limit)) args.append(int(limit))
limit_str = " LIMIT ?"
else:
limit_str = ""
sql = ( sql = (
"SELECT event_id, topological_ordering, stream_ordering" "SELECT event_id, topological_ordering, stream_ordering"
" FROM events" " FROM events"
" WHERE outlier = ? AND room_id = ? AND %(bounds)s" " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
" ORDER BY topological_ordering %(order)s," " ORDER BY topological_ordering %(order)s,"
" stream_ordering %(order)s %(limit)s" " stream_ordering %(order)s LIMIT ?"
) % { ) % {
"bounds": bounds, "bounds": bounds,
"order": order, "order": order,
"limit": limit_str
} }
txn.execute(sql, args) txn.execute(sql, args)

View file

@ -20,6 +20,8 @@ from twisted.internet import defer, reactor, task
import time import time
import logging import logging
from itertools import islice
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -79,3 +81,19 @@ class Clock(object):
except Exception: except Exception:
if not ignore_errs: if not ignore_errs:
raise raise
def batch_iter(iterable, size):
"""batch an iterable up into tuples with a maximum size
Args:
iterable (iterable): the iterable to slice
size (int): the maximum batch size
Returns:
an iterator over the chunks
"""
# make sure we can deal with iterables like lists too
sourceiter = iter(iterable)
# call islice until it returns an empty tuple
return iter(lambda: tuple(islice(sourceiter, size)), ())

View file

@ -17,6 +17,9 @@ from prometheus_client.core import Gauge, REGISTRY, GaugeMetricFamily
import os import os
from six.moves import intern
import six
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5)) CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
caches_by_name = {} caches_by_name = {}
@ -110,7 +113,9 @@ def intern_string(string):
return None return None
try: try:
if six.PY2:
string = string.encode("ascii") string = string.encode("ascii")
return intern(string) return intern(string)
except UnicodeEncodeError: except UnicodeEncodeError:
return string return string