Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2017-03-24 13:59:35 +00:00
commit 8224121502
53 changed files with 1085 additions and 643 deletions

View file

@ -1,3 +1,50 @@
Changes in synapse v0.19.3 (2017-03-20)
=======================================
No changes since v0.19.3-rc2
Changes in synapse v0.19.3-rc2 (2017-03-13)
===========================================
Bug fixes:
* Fix bug in handling of incoming device list updates over federation.
Changes in synapse v0.19.3-rc1 (2017-03-08)
===========================================
Features:
* Add some administration functionalities. Thanks to morteza-araby! (PR #1784)
Changes:
* Reduce database table sizes (PR #1873, #1916, #1923, #1963)
* Update contrib/ to not use syutil. Thanks to andrewshadura! (PR #1907)
* Don't fetch current state when sending an event in common case (PR #1955)
Bug fixes:
* Fix synapse_port_db failure. Thanks to Pneumaticat! (PR #1904)
* Fix caching to not cache error responses (PR #1913)
* Fix APIs to make kick & ban reasons work (PR #1917)
* Fix bugs in the /keys/changes api (PR #1921)
* Fix bug where users couldn't forget rooms they were banned from (PR #1922)
* Fix issue with long language values in pushers API (PR #1925)
* Fix a race in transaction queue (PR #1930)
* Fix dynamic thumbnailing to preserve aspect ratio. Thanks to jkolo! (PR
#1945)
* Fix device list update to not constantly resync (PR #1964)
* Fix potential for huge memory usage when getting device that have
changed (PR #1969)
Changes in synapse v0.19.2 (2017-02-20) Changes in synapse v0.19.2 (2017-02-20)
======================================= =======================================

View file

@ -20,7 +20,7 @@ The overall architecture is::
https://somewhere.org/_matrix https://elsewhere.net/_matrix https://somewhere.org/_matrix https://elsewhere.net/_matrix
``#matrix:matrix.org`` is the official support room for Matrix, and can be ``#matrix:matrix.org`` is the official support room for Matrix, and can be
accessed by any client from https://matrix.org/docs/projects/try-matrix-now or accessed by any client from https://matrix.org/docs/projects/try-matrix-now.html or
via IRC bridge at irc://irc.freenode.net/matrix. via IRC bridge at irc://irc.freenode.net/matrix.
Synapse is currently in rapid development, but as of version 0.5 we believe it Synapse is currently in rapid development, but as of version 0.5 we believe it
@ -68,7 +68,7 @@ or mandatory service provider in Matrix, unlike WhatsApp, Facebook, Hangouts,
etc. etc.
We'd like to invite you to join #matrix:matrix.org (via We'd like to invite you to join #matrix:matrix.org (via
https://matrix.org/docs/projects/try-matrix-now), run a homeserver, take a look https://matrix.org/docs/projects/try-matrix-now.html), run a homeserver, take a look
at the `Matrix spec <https://matrix.org/docs/spec>`_, and experiment with the at the `Matrix spec <https://matrix.org/docs/spec>`_, and experiment with the
`APIs <https://matrix.org/docs/api>`_ and `Client SDKs `APIs <https://matrix.org/docs/api>`_ and `Client SDKs
<http://matrix.org/docs/projects/try-matrix-now.html#client-sdks>`_. <http://matrix.org/docs/projects/try-matrix-now.html#client-sdks>`_.
@ -321,7 +321,7 @@ Debian
Matrix provides official Debian packages via apt from http://matrix.org/packages/debian/. Matrix provides official Debian packages via apt from http://matrix.org/packages/debian/.
Note that these packages do not include a client - choose one from Note that these packages do not include a client - choose one from
https://matrix.org/docs/projects/try-matrix-now/ (or build your own with one of our SDKs :) https://matrix.org/docs/projects/try-matrix-now.html (or build your own with one of our SDKs :)
Fedora Fedora
------ ------
@ -808,7 +808,7 @@ directory of your choice::
Synapse has a number of external dependencies, that are easiest Synapse has a number of external dependencies, that are easiest
to install using pip and a virtualenv:: to install using pip and a virtualenv::
virtualenv env virtualenv -p python2.7 env
source env/bin/activate source env/bin/activate
python synapse/python_dependencies.py | xargs pip install python synapse/python_dependencies.py | xargs pip install
pip install lxml mock pip install lxml mock

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.19.2" __version__ = "0.19.3"

View file

@ -23,7 +23,7 @@ from synapse import event_auth
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes from synapse.api.errors import AuthError, Codes
from synapse.types import UserID from synapse.types import UserID
from synapse.util.logcontext import preserve_context_over_fn from synapse.util import logcontext
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -209,8 +209,7 @@ class Auth(object):
default=[""] default=[""]
)[0] )[0]
if user and access_token and ip_addr: if user and access_token and ip_addr:
preserve_context_over_fn( logcontext.preserve_fn(self.store.insert_client_ip)(
self.store.insert_client_ip,
user=user, user=user,
access_token=access_token, access_token=access_token,
ip=ip_addr, ip=ip_addr,

View file

@ -15,10 +15,172 @@
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, RoomID from synapse.types import UserID, RoomID
from twisted.internet import defer from twisted.internet import defer
import ujson as json import ujson as json
import jsonschema
from jsonschema import FormatChecker
FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
"limit": {
"type": "number"
},
"senders": {
"$ref": "#/definitions/user_id_array"
},
"not_senders": {
"$ref": "#/definitions/user_id_array"
},
# TODO: We don't limit event type values but we probably should...
# check types are valid event types
"types": {
"type": "array",
"items": {
"type": "string"
}
},
"not_types": {
"type": "array",
"items": {
"type": "string"
}
}
}
}
ROOM_FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
"not_rooms": {
"$ref": "#/definitions/room_id_array"
},
"rooms": {
"$ref": "#/definitions/room_id_array"
},
"ephemeral": {
"$ref": "#/definitions/room_event_filter"
},
"include_leave": {
"type": "boolean"
},
"state": {
"$ref": "#/definitions/room_event_filter"
},
"timeline": {
"$ref": "#/definitions/room_event_filter"
},
"account_data": {
"$ref": "#/definitions/room_event_filter"
},
}
}
ROOM_EVENT_FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
"limit": {
"type": "number"
},
"senders": {
"$ref": "#/definitions/user_id_array"
},
"not_senders": {
"$ref": "#/definitions/user_id_array"
},
"types": {
"type": "array",
"items": {
"type": "string"
}
},
"not_types": {
"type": "array",
"items": {
"type": "string"
}
},
"rooms": {
"$ref": "#/definitions/room_id_array"
},
"not_rooms": {
"$ref": "#/definitions/room_id_array"
},
"contains_url": {
"type": "boolean"
}
}
}
USER_ID_ARRAY_SCHEMA = {
"type": "array",
"items": {
"type": "string",
"format": "matrix_user_id"
}
}
ROOM_ID_ARRAY_SCHEMA = {
"type": "array",
"items": {
"type": "string",
"format": "matrix_room_id"
}
}
USER_FILTER_SCHEMA = {
"$schema": "http://json-schema.org/draft-04/schema#",
"description": "schema for a Sync filter",
"type": "object",
"definitions": {
"room_id_array": ROOM_ID_ARRAY_SCHEMA,
"user_id_array": USER_ID_ARRAY_SCHEMA,
"filter": FILTER_SCHEMA,
"room_filter": ROOM_FILTER_SCHEMA,
"room_event_filter": ROOM_EVENT_FILTER_SCHEMA
},
"properties": {
"presence": {
"$ref": "#/definitions/filter"
},
"account_data": {
"$ref": "#/definitions/filter"
},
"room": {
"$ref": "#/definitions/room_filter"
},
"event_format": {
"type": "string",
"enum": ["client", "federation"]
},
"event_fields": {
"type": "array",
"items": {
"type": "string",
# Don't allow '\\' in event field filters. This makes matching
# events a lot easier as we can then use a negative lookbehind
# assertion to split '\.' If we allowed \\ then it would
# incorrectly split '\\.' See synapse.events.utils.serialize_event
"pattern": "^((?!\\\).)*$"
}
}
},
"additionalProperties": False
}
@FormatChecker.cls_checks('matrix_room_id')
def matrix_room_id_validator(room_id_str):
return RoomID.from_string(room_id_str)
@FormatChecker.cls_checks('matrix_user_id')
def matrix_user_id_validator(user_id_str):
return UserID.from_string(user_id_str)
class Filtering(object): class Filtering(object):
@ -53,98 +215,11 @@ class Filtering(object):
# NB: Filters are the complete json blobs. "Definitions" are an # NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of # individual top-level key e.g. public_user_data. Filters are made of
# many definitions. # many definitions.
try:
top_level_definitions = [ jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA,
"presence", "account_data" format_checker=FormatChecker())
] except jsonschema.ValidationError as e:
raise SynapseError(400, e.message)
room_level_definitions = [
"state", "timeline", "ephemeral", "account_data"
]
for key in top_level_definitions:
if key in user_filter_json:
self._check_definition(user_filter_json[key])
if "room" in user_filter_json:
self._check_definition_room_lists(user_filter_json["room"])
for key in room_level_definitions:
if key in user_filter_json["room"]:
self._check_definition(user_filter_json["room"][key])
if "event_fields" in user_filter_json:
if type(user_filter_json["event_fields"]) != list:
raise SynapseError(400, "event_fields must be a list of strings")
for field in user_filter_json["event_fields"]:
if not isinstance(field, basestring):
raise SynapseError(400, "Event field must be a string")
# Don't allow '\\' in event field filters. This makes matching
# events a lot easier as we can then use a negative lookbehind
# assertion to split '\.' If we allowed \\ then it would
# incorrectly split '\\.' See synapse.events.utils.serialize_event
if r'\\' in field:
raise SynapseError(
400, r'The escape character \ cannot itself be escaped'
)
def _check_definition_room_lists(self, definition):
"""Check that "rooms" and "not_rooms" are lists of room ids if they
are present
Args:
definition(dict): The filter definition
Raises:
SynapseError: If there was a problem with this definition.
"""
# check rooms are valid room IDs
room_id_keys = ["rooms", "not_rooms"]
for key in room_id_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for room_id in definition[key]:
RoomID.from_string(room_id)
def _check_definition(self, definition):
"""Check if the provided definition is valid.
This inspects not only the types but also the values to make sure they
make sense.
Args:
definition(dict): The filter definition
Raises:
SynapseError: If there was a problem with this definition.
"""
# NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
if type(definition) != dict:
raise SynapseError(
400, "Expected JSON object, not %s" % (definition,)
)
self._check_definition_room_lists(definition)
# check senders are valid user IDs
user_id_keys = ["senders", "not_senders"]
for key in user_id_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for user_id in definition[key]:
UserID.from_string(user_id)
# TODO: We don't limit event type values but we probably should...
# check types are valid event types
event_keys = ["types", "not_types"]
for key in event_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for event_type in definition[key]:
if not isinstance(event_type, basestring):
raise SynapseError(400, "Event type should be a string")
class FilterCollection(object): class FilterCollection(object):

View file

@ -29,6 +29,7 @@ from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.rest.client.v1.room import PublicRoomListRestServlet from synapse.rest.client.v1.room import PublicRoomListRestServlet
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore from synapse.storage.client_ips import ClientIpStore
@ -63,6 +64,7 @@ class ClientReaderSlavedStore(
DirectoryStore, DirectoryStore,
SlavedApplicationServiceStore, SlavedApplicationServiceStore,
SlavedRegistrationStore, SlavedRegistrationStore,
TransactionStore,
BaseSlavedStore, BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different ClientIpStore, # After BaseSlavedStore because the constructor is different
): ):

View file

@ -24,6 +24,7 @@ from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer from synapse.server import HomeServer
@ -59,6 +60,7 @@ logger = logging.getLogger("synapse.app.media_repository")
class MediaRepositorySlavedStore( class MediaRepositorySlavedStore(
SlavedApplicationServiceStore, SlavedApplicationServiceStore,
SlavedRegistrationStore, SlavedRegistrationStore,
TransactionStore,
BaseSlavedStore, BaseSlavedStore,
MediaRepositoryStore, MediaRepositoryStore,
ClientIpStore, ClientIpStore,

View file

@ -15,7 +15,6 @@
from synapse.crypto.keyclient import fetch_server_key from synapse.crypto.keyclient import fetch_server_key
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import ( from synapse.util.logcontext import (
@ -238,8 +237,14 @@ class Keyring(object):
d.addBoth(rm, server_name) d.addBoth(rm, server_name)
def get_server_verify_keys(self, verify_requests): def get_server_verify_keys(self, verify_requests):
"""Takes a dict of KeyGroups and tries to find at least one key for """Tries to find at least one key for each verify request
each group.
For each verify_request, verify_request.deferred is called back with
params (server_name, key_id, VerifyKey) if a key is found, or errbacked
with a SynapseError if none of the keys are found.
Args:
verify_requests (list[VerifyKeyRequest]): list of verify requests
""" """
# These are functions that produce keys given a list of key ids # These are functions that produce keys given a list of key ids
@ -252,8 +257,11 @@ class Keyring(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_iterations(): def do_iterations():
with Measure(self.clock, "get_server_verify_keys"): with Measure(self.clock, "get_server_verify_keys"):
# dict[str, dict[str, VerifyKey]]: results so far.
# map server_name -> key_id -> VerifyKey
merged_results = {} merged_results = {}
# dict[str, set(str)]: keys to fetch for each server
missing_keys = {} missing_keys = {}
for verify_request in verify_requests: for verify_request in verify_requests:
missing_keys.setdefault(verify_request.server_name, set()).update( missing_keys.setdefault(verify_request.server_name, set()).update(
@ -315,6 +323,16 @@ class Keyring(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids): def get_keys_from_store(self, server_name_and_key_ids):
"""
Args:
server_name_and_key_ids (list[(str, iterable[str])]):
list of (server_name, iterable[key_id]) tuples to fetch keys for
Returns:
Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
server_name -> key_id -> VerifyKey
"""
res = yield preserve_context_over_deferred(defer.gatherResults( res = yield preserve_context_over_deferred(defer.gatherResults(
[ [
preserve_fn(self.store.get_server_verify_keys)( preserve_fn(self.store.get_server_verify_keys)(
@ -363,30 +381,24 @@ class Keyring(object):
def get_keys_from_server(self, server_name_and_key_ids): def get_keys_from_server(self, server_name_and_key_ids):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_key(server_name, key_ids): def get_key(server_name, key_ids):
limiter = yield get_retry_limiter( keys = None
server_name, try:
self.clock, keys = yield self.get_server_verify_key_v2_direct(
self.store, server_name, key_ids
) )
with limiter: except Exception as e:
keys = None logger.info(
try: "Unable to get key %r for %r directly: %s %s",
keys = yield self.get_server_verify_key_v2_direct( key_ids, server_name,
server_name, key_ids type(e).__name__, str(e.message),
) )
except Exception as e:
logger.info(
"Unable to get key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
if not keys: if not keys:
keys = yield self.get_server_verify_key_v1_direct( keys = yield self.get_server_verify_key_v1_direct(
server_name, key_ids server_name, key_ids
) )
keys = {server_name: keys} keys = {server_name: keys}
defer.returnValue(keys) defer.returnValue(keys)

View file

@ -29,7 +29,7 @@ from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent, builder from synapse.events import FrozenEvent, builder
import synapse.metrics import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
import copy import copy
import itertools import itertools
@ -88,7 +88,7 @@ class FederationClient(FederationBase):
@log_function @log_function
def make_query(self, destination, query_type, args, def make_query(self, destination, query_type, args,
retry_on_dns_fail=False): retry_on_dns_fail=False, ignore_backoff=False):
"""Sends a federation Query to a remote homeserver of the given type """Sends a federation Query to a remote homeserver of the given type
and arguments. and arguments.
@ -98,6 +98,8 @@ class FederationClient(FederationBase):
handler name used in register_query_handler(). handler name used in register_query_handler().
args (dict): Mapping of strings to strings containing the details args (dict): Mapping of strings to strings containing the details
of the query request. of the query request.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
Returns: Returns:
a Deferred which will eventually yield a JSON object from the a Deferred which will eventually yield a JSON object from the
@ -106,7 +108,8 @@ class FederationClient(FederationBase):
sent_queries_counter.inc(query_type) sent_queries_counter.inc(query_type)
return self.transport_layer.make_query( return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
) )
@log_function @log_function
@ -234,31 +237,24 @@ class FederationClient(FederationBase):
continue continue
try: try:
limiter = yield get_retry_limiter( transaction_data = yield self.transport_layer.get_event(
destination, destination, event_id, timeout=timeout,
self._clock,
self.store,
) )
with limiter: logger.debug("transaction_data %r", transaction_data)
transaction_data = yield self.transport_layer.get_event(
destination, event_id, timeout=timeout,
)
logger.debug("transaction_data %r", transaction_data) pdu_list = [
self.event_from_pdu_json(p, outlier=outlier)
for p in transaction_data["pdus"]
]
pdu_list = [ if pdu_list and pdu_list[0]:
self.event_from_pdu_json(p, outlier=outlier) pdu = pdu_list[0]
for p in transaction_data["pdus"]
]
if pdu_list and pdu_list[0]: # Check signatures are correct.
pdu = pdu_list[0] signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
# Check signatures are correct. break
signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
break
pdu_attempts[destination] = now pdu_attempts[destination] = now

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import datetime
from twisted.internet import defer from twisted.internet import defer
@ -22,9 +22,7 @@ from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.retryutils import ( from synapse.util.retryutils import NotRetryingDestination
get_retry_limiter, NotRetryingDestination,
)
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
@ -312,13 +310,6 @@ class TransactionQueue(object):
yield run_on_reactor() yield run_on_reactor()
while True: while True:
limiter = yield get_retry_limiter(
destination,
self.clock,
self.store,
backoff_on_404=True, # If we get a 404 the other side has gone
)
device_message_edus, device_stream_id, dev_list_id = ( device_message_edus, device_stream_id, dev_list_id = (
yield self._get_new_device_messages(destination) yield self._get_new_device_messages(destination)
) )
@ -374,7 +365,6 @@ class TransactionQueue(object):
success = yield self._send_new_transaction( success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures, destination, pending_pdus, pending_edus, pending_failures,
limiter=limiter,
) )
if success: if success:
# Remove the acknowledged device messages from the database # Remove the acknowledged device messages from the database
@ -392,12 +382,24 @@ class TransactionQueue(object):
self.last_device_list_stream_id_by_dest[destination] = dev_list_id self.last_device_list_stream_id_by_dest[destination] = dev_list_id
else: else:
break break
except NotRetryingDestination: except NotRetryingDestination as e:
logger.debug( logger.debug(
"TX [%s] not ready for retry yet - " "TX [%s] not ready for retry yet (next retry at %s) - "
"dropping transaction for now", "dropping transaction for now",
destination, destination,
datetime.datetime.fromtimestamp(
(e.retry_last_ts + e.retry_interval) / 1000.0
),
) )
except Exception as e:
logger.warn(
"TX [%s] Failed to send transaction: %s",
destination,
e,
)
for p in pending_pdus:
logger.info("Failed to send event %s to %s", p.event_id,
destination)
finally: finally:
# We want to be *very* sure we delete this after we stop processing # We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None) self.pending_transactions.pop(destination, None)
@ -437,7 +439,7 @@ class TransactionQueue(object):
@measure_func("_send_new_transaction") @measure_func("_send_new_transaction")
@defer.inlineCallbacks @defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus, def _send_new_transaction(self, destination, pending_pdus, pending_edus,
pending_failures, limiter): pending_failures):
# Sort based on the order field # Sort based on the order field
pending_pdus.sort(key=lambda t: t[1]) pending_pdus.sort(key=lambda t: t[1])
@ -447,132 +449,104 @@ class TransactionQueue(object):
success = True success = True
logger.debug("TX [%s] _attempt_new_transaction", destination)
txn_id = str(self._next_txn_id)
logger.debug(
"TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination, txn_id,
len(pdus),
len(edus),
len(failures)
)
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
origin_server_ts=int(self.clock.time_msec()),
transaction_id=txn_id,
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
pdu_failures=failures,
)
self._next_txn_id += 1
yield self.transaction_actions.prepare_to_send(transaction)
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] {%s} Sending transaction [%s],"
" (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id,
transaction.transaction_id,
len(pdus),
len(edus),
len(failures),
)
# Actually send the transaction
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def json_data_cb():
data = transaction.get_dict()
now = int(self.clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data
try: try:
logger.debug("TX [%s] _attempt_new_transaction", destination) response = yield self.transport_layer.send_transaction(
transaction, json_data_cb
txn_id = str(self._next_txn_id)
logger.debug(
"TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination, txn_id,
len(pdus),
len(edus),
len(failures)
) )
code = 200
logger.debug("TX [%s] Persisting transaction...", destination) if response:
for e_id, r in response.get("pdus", {}).items():
transaction = Transaction.create_new( if "error" in r:
origin_server_ts=int(self.clock.time_msec()), logger.warn(
transaction_id=txn_id, "Transaction returned error for %s: %s",
origin=self.server_name, e_id, r,
destination=destination,
pdus=pdus,
edus=edus,
pdu_failures=failures,
)
self._next_txn_id += 1
yield self.transaction_actions.prepare_to_send(transaction)
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] {%s} Sending transaction [%s],"
" (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id,
transaction.transaction_id,
len(pdus),
len(edus),
len(failures),
)
with limiter:
# Actually send the transaction
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def json_data_cb():
data = transaction.get_dict()
now = int(self.clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data
try:
response = yield self.transport_layer.send_transaction(
transaction, json_data_cb
)
code = 200
if response:
for e_id, r in response.get("pdus", {}).items():
if "error" in r:
logger.warn(
"Transaction returned error for %s: %s",
e_id, r,
)
except HttpResponseException as e:
code = e.code
response = e.response
if e.code in (401, 404, 429) or 500 <= e.code:
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
) )
raise e except HttpResponseException as e:
code = e.code
response = e.response
if e.code in (401, 404, 429) or 500 <= e.code:
logger.info( logger.info(
"TX [%s] {%s} got %d response", "TX [%s] {%s} got %d response",
destination, txn_id, code destination, txn_id, code
) )
raise e
logger.debug("TX [%s] Sent transaction", destination) logger.info(
logger.debug("TX [%s] Marking as delivered...", destination) "TX [%s] {%s} got %d response",
destination, txn_id, code
)
yield self.transaction_actions.delivered( logger.debug("TX [%s] Sent transaction", destination)
transaction, code, response logger.debug("TX [%s] Marking as delivered...", destination)
)
logger.debug("TX [%s] Marked as delivered", destination) yield self.transaction_actions.delivered(
transaction, code, response
)
if code != 200: logger.debug("TX [%s] Marked as delivered", destination)
for p in pdus:
logger.info(
"Failed to send event %s to %s", p.event_id, destination
)
success = False
except RuntimeError as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
success = False
if code != 200:
for p in pdus: for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination) logger.info(
except Exception as e: "Failed to send event %s to %s", p.event_id, destination
# We capture this here as there as nothing actually listens )
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
success = False success = False
for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination)
defer.returnValue(success) defer.returnValue(success)

View file

@ -163,6 +163,7 @@ class TransportLayerClient(object):
data=json_data, data=json_data,
json_data_callback=json_data_callback, json_data_callback=json_data_callback,
long_retries=True, long_retries=True,
backoff_on_404=True, # If we get a 404 the other side has gone
) )
logger.debug( logger.debug(
@ -174,7 +175,8 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def make_query(self, destination, query_type, args, retry_on_dns_fail): def make_query(self, destination, query_type, args, retry_on_dns_fail,
ignore_backoff=False):
path = PREFIX + "/query/%s" % query_type path = PREFIX + "/query/%s" % query_type
content = yield self.client.get_json( content = yield self.client.get_json(
@ -183,6 +185,7 @@ class TransportLayerClient(object):
args=args, args=args,
retry_on_dns_fail=retry_on_dns_fail, retry_on_dns_fail=retry_on_dns_fail,
timeout=10000, timeout=10000,
ignore_backoff=ignore_backoff,
) )
defer.returnValue(content) defer.returnValue(content)
@ -242,6 +245,7 @@ class TransportLayerClient(object):
destination=destination, destination=destination,
path=path, path=path,
data=content, data=content,
ignore_backoff=True,
) )
defer.returnValue(response) defer.returnValue(response)
@ -269,6 +273,7 @@ class TransportLayerClient(object):
destination=remote_server, destination=remote_server,
path=path, path=path,
args=args, args=args,
ignore_backoff=True,
) )
defer.returnValue(response) defer.returnValue(response)

View file

@ -175,6 +175,7 @@ class DirectoryHandler(BaseHandler):
"room_alias": room_alias.to_string(), "room_alias": room_alias.to_string(),
}, },
retry_on_dns_fail=False, retry_on_dns_fail=False,
ignore_backoff=True,
) )
except CodeMessageException as e: except CodeMessageException as e:
logging.warn("Error retrieving alias") logging.warn("Error retrieving alias")

View file

@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException from synapse.api.errors import SynapseError, CodeMessageException
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -121,15 +121,11 @@ class E2eKeysHandler(object):
def do_remote_query(destination): def do_remote_query(destination):
destination_query = remote_queries_not_in_cache[destination] destination_query = remote_queries_not_in_cache[destination]
try: try:
limiter = yield get_retry_limiter( remote_result = yield self.federation.query_client_keys(
destination, self.clock, self.store destination,
{"device_keys": destination_query},
timeout=timeout
) )
with limiter:
remote_result = yield self.federation.query_client_keys(
destination,
{"device_keys": destination_query},
timeout=timeout
)
for user_id, keys in remote_result["device_keys"].items(): for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query: if user_id in destination_query:
@ -239,18 +235,14 @@ class E2eKeysHandler(object):
def claim_client_keys(destination): def claim_client_keys(destination):
device_keys = remote_queries[destination] device_keys = remote_queries[destination]
try: try:
limiter = yield get_retry_limiter( remote_result = yield self.federation.claim_client_keys(
destination, self.clock, self.store destination,
{"one_time_keys": device_keys},
timeout=timeout
) )
with limiter: for user_id, keys in remote_result["one_time_keys"].items():
remote_result = yield self.federation.claim_client_keys( if user_id in device_keys:
destination, json_result[user_id] = keys
{"one_time_keys": device_keys},
timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
except CodeMessageException as e: except CodeMessageException as e:
failures[destination] = { failures[destination] = {
"status": e.code, "message": e.message "status": e.code, "message": e.message
@ -316,7 +308,7 @@ class E2eKeysHandler(object):
# old access_token without an associated device_id. Either way, we # old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with # need to double-check the device is registered to avoid ending up with
# keys without a corresponding device. # keys without a corresponding device.
self.device_handler.check_device_registered(user_id, device_id) yield self.device_handler.check_device_registered(user_id, device_id)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id) result = yield self.store.count_e2e_one_time_keys(user_id, device_id)

View file

@ -578,8 +578,7 @@ class PresenceHandler(object):
if not local_states: if not local_states:
continue continue
users = yield self.store.get_users_in_room(room_id) hosts = yield self.store.get_hosts_in_room(room_id)
hosts = set(get_domain_from_id(u) for u in users)
for host in hosts: for host in hosts:
hosts_to_states.setdefault(host, []).extend(local_states) hosts_to_states.setdefault(host, []).extend(local_states)

View file

@ -52,7 +52,8 @@ class ProfileHandler(BaseHandler):
args={ args={
"user_id": target_user.to_string(), "user_id": target_user.to_string(),
"field": "displayname", "field": "displayname",
} },
ignore_backoff=True,
) )
except CodeMessageException as e: except CodeMessageException as e:
if e.code != 404: if e.code != 404:
@ -99,7 +100,8 @@ class ProfileHandler(BaseHandler):
args={ args={
"user_id": target_user.to_string(), "user_id": target_user.to_string(),
"field": "avatar_url", "field": "avatar_url",
} },
ignore_backoff=True,
) )
except CodeMessageException as e: except CodeMessageException as e:
if e.code != 404: if e.code != 404:

View file

@ -12,8 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import synapse.util.retryutils
from twisted.internet import defer, reactor, protocol from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, HTTPConnectionPool, Agent from twisted.web.client import readBody, HTTPConnectionPool, Agent
@ -22,7 +21,7 @@ from twisted.web._newclient import ResponseDone
from synapse.http.endpoint import matrix_federation_endpoint from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.async import sleep from synapse.util.async import sleep
from synapse.util.logcontext import preserve_context_over_fn from synapse.util import logcontext
import synapse.metrics import synapse.metrics
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
@ -94,6 +93,7 @@ class MatrixFederationHttpClient(object):
reactor, MatrixFederationEndpointFactory(hs), pool=pool reactor, MatrixFederationEndpointFactory(hs), pool=pool
) )
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._store = hs.get_datastore()
self.version_string = hs.version_string self.version_string = hs.version_string
self._next_id = 1 self._next_id = 1
@ -103,129 +103,152 @@ class MatrixFederationHttpClient(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes, def _request(self, destination, method, path,
body_callback, headers_dict={}, param_bytes=b"", body_callback, headers_dict={}, param_bytes=b"",
query_bytes=b"", retry_on_dns_fail=True, query_bytes=b"", retry_on_dns_fail=True,
timeout=None, long_retries=False): timeout=None, long_retries=False,
""" Creates and sends a request to the given url ignore_backoff=False,
backoff_on_404=False):
""" Creates and sends a request to the given server
Args:
destination (str): The remote server to send the HTTP request to.
method (str): HTTP method
path (str): The HTTP path
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
backoff_on_404 (bool): Back off if we get a 404
Returns: Returns:
Deferred: resolves with the http response object on success. Deferred: resolves with the http response object on success.
Fails with ``HTTPRequestException``: if we get an HTTP response Fails with ``HTTPRequestException``: if we get an HTTP response
code >= 300. code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
""" """
headers_dict[b"User-Agent"] = [self.version_string] limiter = yield synapse.util.retryutils.get_retry_limiter(
headers_dict[b"Host"] = [destination] destination,
self.clock,
url_bytes = self._create_url( self._store,
destination, path_bytes, param_bytes, query_bytes backoff_on_404=backoff_on_404,
ignore_backoff=ignore_backoff,
) )
txn_id = "%s-O-%s" % (method, self._next_id) destination = destination.encode("ascii")
self._next_id = (self._next_id + 1) % (sys.maxint - 1) path_bytes = path.encode("ascii")
with limiter:
headers_dict[b"User-Agent"] = [self.version_string]
headers_dict[b"Host"] = [destination]
outbound_logger.info( url_bytes = self._create_url(
"{%s} [%s] Sending request: %s %s", destination, path_bytes, param_bytes, query_bytes
txn_id, destination, method, url_bytes )
)
# XXX: Would be much nicer to retry only at the transaction-layer txn_id = "%s-O-%s" % (method, self._next_id)
# (once we have reliable transactions in place) self._next_id = (self._next_id + 1) % (sys.maxint - 1)
if long_retries:
retries_left = MAX_LONG_RETRIES
else:
retries_left = MAX_SHORT_RETRIES
http_url_bytes = urlparse.urlunparse( outbound_logger.info(
("", "", path_bytes, param_bytes, query_bytes, "") "{%s} [%s] Sending request: %s %s",
) txn_id, destination, method, url_bytes
)
log_result = None # XXX: Would be much nicer to retry only at the transaction-layer
try: # (once we have reliable transactions in place)
while True: if long_retries:
producer = None retries_left = MAX_LONG_RETRIES
if body_callback: else:
producer = body_callback(method, http_url_bytes, headers_dict) retries_left = MAX_SHORT_RETRIES
try: http_url_bytes = urlparse.urlunparse(
def send_request(): ("", "", path_bytes, param_bytes, query_bytes, "")
request_deferred = preserve_context_over_fn( )
self.agent.request,
log_result = None
try:
while True:
producer = None
if body_callback:
producer = body_callback(method, http_url_bytes, headers_dict)
try:
def send_request():
request_deferred = self.agent.request(
method,
url_bytes,
Headers(headers_dict),
producer
)
return self.clock.time_bound_deferred(
request_deferred,
time_out=timeout / 1000. if timeout else 60,
)
with logcontext.PreserveLoggingContext():
response = yield send_request()
log_result = "%d %s" % (response.code, response.phrase,)
break
except Exception as e:
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
logger.warn(
"DNS Lookup failed to %s with %s",
destination,
e
)
log_result = "DNS Lookup failed to %s with %s" % (
destination, e
)
raise
logger.warn(
"{%s} Sending request failed to %s: %s %s: %s - %s",
txn_id,
destination,
method, method,
url_bytes, url_bytes,
Headers(headers_dict), type(e).__name__,
producer _flatten_response_never_received(e),
) )
return self.clock.time_bound_deferred( log_result = "%s - %s" % (
request_deferred, type(e).__name__, _flatten_response_never_received(e),
time_out=timeout / 1000. if timeout else 60,
) )
response = yield preserve_context_over_fn(send_request) if retries_left and not timeout:
if long_retries:
delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
delay = min(delay, 60)
delay *= random.uniform(0.8, 1.4)
else:
delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
delay = min(delay, 2)
delay *= random.uniform(0.8, 1.4)
log_result = "%d %s" % (response.code, response.phrase,) yield sleep(delay)
break retries_left -= 1
except Exception as e:
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
logger.warn(
"DNS Lookup failed to %s with %s",
destination,
e
)
log_result = "DNS Lookup failed to %s with %s" % (
destination, e
)
raise
logger.warn(
"{%s} Sending request failed to %s: %s %s: %s - %s",
txn_id,
destination,
method,
url_bytes,
type(e).__name__,
_flatten_response_never_received(e),
)
log_result = "%s - %s" % (
type(e).__name__, _flatten_response_never_received(e),
)
if retries_left and not timeout:
if long_retries:
delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
delay = min(delay, 60)
delay *= random.uniform(0.8, 1.4)
else: else:
delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left) raise
delay = min(delay, 2) finally:
delay *= random.uniform(0.8, 1.4) outbound_logger.info(
"{%s} [%s] Result: %s",
txn_id,
destination,
log_result,
)
yield sleep(delay) if 200 <= response.code < 300:
retries_left -= 1 pass
else: else:
raise # :'(
finally: # Update transactions table?
outbound_logger.info( with logcontext.PreserveLoggingContext():
"{%s} [%s] Result: %s", body = yield readBody(response)
txn_id, raise HttpResponseException(
destination, response.code, response.phrase, body
log_result, )
)
if 200 <= response.code < 300: defer.returnValue(response)
pass
else:
# :'(
# Update transactions table?
body = yield preserve_context_over_fn(readBody, response)
raise HttpResponseException(
response.code, response.phrase, body
)
defer.returnValue(response)
def sign_request(self, destination, method, url_bytes, headers_dict, def sign_request(self, destination, method, url_bytes, headers_dict,
content=None): content=None):
@ -254,7 +277,9 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def put_json(self, destination, path, data={}, json_data_callback=None, def put_json(self, destination, path, data={}, json_data_callback=None,
long_retries=False, timeout=None): long_retries=False, timeout=None,
ignore_backoff=False,
backoff_on_404=False):
""" Sends the specifed json data using PUT """ Sends the specifed json data using PUT
Args: Args:
@ -269,11 +294,19 @@ class MatrixFederationHttpClient(object):
retry for a short or long time. retry for a short or long time.
timeout(int): How long to try (in ms) the destination for before timeout(int): How long to try (in ms) the destination for before
giving up. None indicates no timeout. giving up. None indicates no timeout.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
backoff_on_404 (bool): True if we should count a 404 response as
a failure of the server (and should therefore back off future
requests)
Returns: Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. On a 4xx or 5xx error response a will be the decoded JSON body. On a 4xx or 5xx error response a
CodeMessageException is raised. CodeMessageException is raised.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
""" """
if not json_data_callback: if not json_data_callback:
@ -288,26 +321,29 @@ class MatrixFederationHttpClient(object):
producer = _JsonProducer(json_data) producer = _JsonProducer(json_data)
return producer return producer
response = yield self._create_request( response = yield self._request(
destination.encode("ascii"), destination,
"PUT", "PUT",
path.encode("ascii"), path,
body_callback=body_callback, body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]}, headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries, long_retries=long_retries,
timeout=timeout, timeout=timeout,
ignore_backoff=ignore_backoff,
backoff_on_404=backoff_on_404,
) )
if 200 <= response.code < 300: if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent? # We need to update the transactions table to say it was sent?
check_content_type_is_json(response.headers) check_content_type_is_json(response.headers)
body = yield preserve_context_over_fn(readBody, response) with logcontext.PreserveLoggingContext():
body = yield readBody(response)
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def post_json(self, destination, path, data={}, long_retries=False, def post_json(self, destination, path, data={}, long_retries=False,
timeout=None): timeout=None, ignore_backoff=False):
""" Sends the specifed json data using POST """ Sends the specifed json data using POST
Args: Args:
@ -320,11 +356,15 @@ class MatrixFederationHttpClient(object):
retry for a short or long time. retry for a short or long time.
timeout(int): How long to try (in ms) the destination for before timeout(int): How long to try (in ms) the destination for before
giving up. None indicates no timeout. giving up. None indicates no timeout.
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
Returns: Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. On a 4xx or 5xx error response a will be the decoded JSON body. On a 4xx or 5xx error response a
CodeMessageException is raised. CodeMessageException is raised.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
""" """
def body_callback(method, url_bytes, headers_dict): def body_callback(method, url_bytes, headers_dict):
@ -333,27 +373,29 @@ class MatrixFederationHttpClient(object):
) )
return _JsonProducer(data) return _JsonProducer(data)
response = yield self._create_request( response = yield self._request(
destination.encode("ascii"), destination,
"POST", "POST",
path.encode("ascii"), path,
body_callback=body_callback, body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]}, headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries, long_retries=long_retries,
timeout=timeout, timeout=timeout,
ignore_backoff=ignore_backoff,
) )
if 200 <= response.code < 300: if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent? # We need to update the transactions table to say it was sent?
check_content_type_is_json(response.headers) check_content_type_is_json(response.headers)
body = yield preserve_context_over_fn(readBody, response) with logcontext.PreserveLoggingContext():
body = yield readBody(response)
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True, def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
timeout=None): timeout=None, ignore_backoff=False):
""" GETs some json from the given host homeserver and path """ GETs some json from the given host homeserver and path
Args: Args:
@ -365,11 +407,16 @@ class MatrixFederationHttpClient(object):
timeout (int): How long to try (in ms) the destination for before timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout and that the request will giving up. None indicates no timeout and that the request will
be retried. be retried.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
Returns: Returns:
Deferred: Succeeds when we get *any* HTTP response. Deferred: Succeeds when we get *any* HTTP response.
The result of the deferred is a tuple of `(code, response)`, The result of the deferred is a tuple of `(code, response)`,
where `response` is a dict representing the decoded JSON body. where `response` is a dict representing the decoded JSON body.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
""" """
logger.debug("get_json args: %s", args) logger.debug("get_json args: %s", args)
@ -386,39 +433,47 @@ class MatrixFederationHttpClient(object):
self.sign_request(destination, method, url_bytes, headers_dict) self.sign_request(destination, method, url_bytes, headers_dict)
return None return None
response = yield self._create_request( response = yield self._request(
destination.encode("ascii"), destination,
"GET", "GET",
path.encode("ascii"), path,
query_bytes=query_bytes, query_bytes=query_bytes,
body_callback=body_callback, body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail, retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout, timeout=timeout,
ignore_backoff=ignore_backoff,
) )
if 200 <= response.code < 300: if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent? # We need to update the transactions table to say it was sent?
check_content_type_is_json(response.headers) check_content_type_is_json(response.headers)
body = yield preserve_context_over_fn(readBody, response) with logcontext.PreserveLoggingContext():
body = yield readBody(response)
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_file(self, destination, path, output_stream, args={}, def get_file(self, destination, path, output_stream, args={},
retry_on_dns_fail=True, max_size=None): retry_on_dns_fail=True, max_size=None,
ignore_backoff=False):
"""GETs a file from a given homeserver """GETs a file from a given homeserver
Args: Args:
destination (str): The remote server to send the HTTP request to. destination (str): The remote server to send the HTTP request to.
path (str): The HTTP path to GET. path (str): The HTTP path to GET.
output_stream (file): File to write the response body to. output_stream (file): File to write the response body to.
args (dict): Optional dictionary used to create the query string. args (dict): Optional dictionary used to create the query string.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
Returns: Returns:
Deferred: resolves with an (int,dict) tuple of the file length and Deferred: resolves with an (int,dict) tuple of the file length and
a dict of the response headers. a dict of the response headers.
Fails with ``HTTPRequestException`` if we get an HTTP response code Fails with ``HTTPRequestException`` if we get an HTTP response code
>= 300 >= 300
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
""" """
encoded_args = {} encoded_args = {}
@ -434,22 +489,23 @@ class MatrixFederationHttpClient(object):
self.sign_request(destination, method, url_bytes, headers_dict) self.sign_request(destination, method, url_bytes, headers_dict)
return None return None
response = yield self._create_request( response = yield self._request(
destination.encode("ascii"), destination,
"GET", "GET",
path.encode("ascii"), path,
query_bytes=query_bytes, query_bytes=query_bytes,
body_callback=body_callback, body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
) )
headers = dict(response.headers.getAllRawHeaders()) headers = dict(response.headers.getAllRawHeaders())
try: try:
length = yield preserve_context_over_fn( with logcontext.PreserveLoggingContext():
_readBodyToFile, length = yield _readBodyToFile(
response, output_stream, max_size response, output_stream, max_size
) )
except: except:
logger.exception("Failed to download body") logger.exception("Failed to download body")
raise raise

View file

@ -19,6 +19,7 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
REQUIREMENTS = { REQUIREMENTS = {
"jsonschema>=2.5.1": ["jsonschema>=2.5.1"],
"frozendict>=0.4": ["frozendict"], "frozendict>=0.4": ["frozendict"],
"unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],

View file

@ -167,7 +167,6 @@ class SlavedEventStore(BaseSlavedStore):
_get_rooms_for_user_where_membership_is_txn = ( _get_rooms_for_user_where_membership_is_txn = (
DataStore._get_rooms_for_user_where_membership_is_txn.__func__ DataStore._get_rooms_for_user_where_membership_is_txn.__func__
) )
_get_members_rows_txn = DataStore._get_members_rows_txn.__func__
_get_state_for_groups = DataStore._get_state_for_groups.__func__ _get_state_for_groups = DataStore._get_state_for_groups.__func__
_get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__ _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
_get_events_around_txn = DataStore._get_events_around_txn.__func__ _get_events_around_txn = DataStore._get_events_around_txn.__func__

View file

@ -57,5 +57,6 @@ class SlavedPresenceStore(BaseSlavedStore):
self.presence_stream_cache.entity_has_changed( self.presence_stream_cache.entity_has_changed(
user_id, position user_id, position
) )
self._get_presence_for_user.invalidate((user_id,))
return super(SlavedPresenceStore, self).process_replication(result) return super(SlavedPresenceStore, self).process_replication(result)

View file

@ -268,7 +268,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
if existingUid is not None: if existingUid is not None:
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body) ret = yield self.identity_handler.requestMsisdnToken(**body)
defer.returnValue((200, ret)) defer.returnValue((200, ret))

View file

@ -537,7 +537,7 @@ class RegisterRestServlet(RestServlet):
# we have nowhere to store it. # we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name") initial_display_name = params.get("initial_device_display_name")
self.device_handler.check_device_registered( yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name user_id, device_id, initial_display_name
) )

View file

@ -73,6 +73,9 @@ class LoggingTransaction(object):
def __setattr__(self, name, value): def __setattr__(self, name, value):
setattr(self.txn, name, value) setattr(self.txn, name, value)
def __iter__(self):
return self.txn.__iter__()
def execute(self, sql, *args): def execute(self, sql, *args):
self._do_execute(self.txn.execute, sql, *args) self._do_execute(self.txn.execute, sql, *args)
@ -132,7 +135,7 @@ class PerformanceCounters(object):
def interval(self, interval_duration, limit=3): def interval(self, interval_duration, limit=3):
counters = [] counters = []
for name, (count, cum_time) in self.current_counters.items(): for name, (count, cum_time) in self.current_counters.iteritems():
prev_count, prev_time = self.previous_counters.get(name, (0, 0)) prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append(( counters.append((
(cum_time - prev_time) / interval_duration, (cum_time - prev_time) / interval_duration,
@ -357,7 +360,7 @@ class SQLBaseStore(object):
""" """
col_headers = list(intern(column[0]) for column in cursor.description) col_headers = list(intern(column[0]) for column in cursor.description)
results = list( results = list(
dict(zip(col_headers, row)) for row in cursor.fetchall() dict(zip(col_headers, row)) for row in cursor
) )
return results return results
@ -565,7 +568,7 @@ class SQLBaseStore(object):
@staticmethod @staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol): def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
if keyvalues: if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else: else:
where = "" where = ""
@ -579,7 +582,7 @@ class SQLBaseStore(object):
txn.execute(sql, keyvalues.values()) txn.execute(sql, keyvalues.values())
return [r[0] for r in txn.fetchall()] return [r[0] for r in txn]
def _simple_select_onecol(self, table, keyvalues, retcol, def _simple_select_onecol(self, table, keyvalues, retcol,
desc="_simple_select_onecol"): desc="_simple_select_onecol"):
@ -712,7 +715,7 @@ class SQLBaseStore(object):
) )
values.extend(iterable) values.extend(iterable)
for key, value in keyvalues.items(): for key, value in keyvalues.iteritems():
clauses.append("%s = ?" % (key,)) clauses.append("%s = ?" % (key,))
values.append(value) values.append(value)
@ -753,7 +756,7 @@ class SQLBaseStore(object):
@staticmethod @staticmethod
def _simple_update_one_txn(txn, table, keyvalues, updatevalues): def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
if keyvalues: if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else: else:
where = "" where = ""
@ -870,7 +873,7 @@ class SQLBaseStore(object):
) )
values.extend(iterable) values.extend(iterable)
for key, value in keyvalues.items(): for key, value in keyvalues.iteritems():
clauses.append("%s = ?" % (key,)) clauses.append("%s = ?" % (key,))
values.append(value) values.append(value)
@ -901,16 +904,16 @@ class SQLBaseStore(object):
txn = db_conn.cursor() txn = db_conn.cursor()
txn.execute(sql, (int(max_value),)) txn.execute(sql, (int(max_value),))
rows = txn.fetchall()
txn.close()
cache = { cache = {
row[0]: int(row[1]) row[0]: int(row[1])
for row in rows for row in txn
} }
txn.close()
if cache: if cache:
min_val = min(cache.values()) min_val = min(cache.itervalues())
else: else:
min_val = max_value min_val = max_value

View file

@ -182,7 +182,7 @@ class AccountDataStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id)) txn.execute(sql, (user_id, stream_id))
global_account_data = { global_account_data = {
row[0]: json.loads(row[1]) for row in txn.fetchall() row[0]: json.loads(row[1]) for row in txn
} }
sql = ( sql = (
@ -193,7 +193,7 @@ class AccountDataStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id)) txn.execute(sql, (user_id, stream_id))
account_data_by_room = {} account_data_by_room = {}
for row in txn.fetchall(): for row in txn:
room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = json.loads(row[2]) room_account_data[row[1]] = json.loads(row[2])

View file

@ -178,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
) )
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
message_json = ujson.dumps(messages_by_device["*"]) message_json = ujson.dumps(messages_by_device["*"])
for row in txn.fetchall(): for row in txn:
# Add the message for all devices for this user on this # Add the message for all devices for this user on this
# server. # server.
device = row[0] device = row[0]
@ -195,7 +195,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
# TODO: Maybe this needs to be done in batches if there are # TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user. # too many local devices for a given user.
txn.execute(sql, [user_id] + devices) txn.execute(sql, [user_id] + devices)
for row in txn.fetchall(): for row in txn:
# Only insert into the local inbox if the device exists on # Only insert into the local inbox if the device exists on
# this server # this server
device = row[0] device = row[0]
@ -251,7 +251,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
user_id, device_id, last_stream_id, current_stream_id, limit user_id, device_id, last_stream_id, current_stream_id, limit
)) ))
messages = [] messages = []
for row in txn.fetchall(): for row in txn:
stream_pos = row[0] stream_pos = row[0]
messages.append(ujson.loads(row[1])) messages.append(ujson.loads(row[1]))
if len(messages) < limit: if len(messages) < limit:
@ -340,7 +340,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
" ORDER BY stream_id ASC" " ORDER BY stream_id ASC"
) )
txn.execute(sql, (last_pos, upper_pos)) txn.execute(sql, (last_pos, upper_pos))
rows.extend(txn.fetchall()) rows.extend(txn)
return rows return rows
@ -384,7 +384,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
destination, last_stream_id, current_stream_id, limit destination, last_stream_id, current_stream_id, limit
)) ))
messages = [] messages = []
for row in txn.fetchall(): for row in txn:
stream_pos = row[0] stream_pos = row[0]
messages.append(ujson.loads(row[1])) messages.append(ujson.loads(row[1]))
if len(messages) < limit: if len(messages) < limit:

View file

@ -333,13 +333,12 @@ class DeviceStore(SQLBaseStore):
txn.execute( txn.execute(
sql, (destination, from_stream_id, now_stream_id, False) sql, (destination, from_stream_id, now_stream_id, False)
) )
rows = txn.fetchall()
if not rows:
return (now_stream_id, [])
# maps (user_id, device_id) -> stream_id # maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in rows} query_map = {(r[0], r[1]): r[2] for r in txn}
if not query_map:
return (now_stream_id, [])
devices = self._get_e2e_device_keys_txn( devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True txn, query_map.keys(), include_all_devices=True
) )

View file

@ -153,7 +153,7 @@ class EndToEndKeyStore(SQLBaseStore):
) )
txn.execute(sql, (user_id, device_id)) txn.execute(sql, (user_id, device_id))
result = {} result = {}
for algorithm, key_count in txn.fetchall(): for algorithm, key_count in txn:
result[algorithm] = key_count result[algorithm] = key_count
return result return result
return self.runInteraction( return self.runInteraction(
@ -174,7 +174,7 @@ class EndToEndKeyStore(SQLBaseStore):
user_result = result.setdefault(user_id, {}) user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {}) device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm)) txn.execute(sql, (user_id, device_id, algorithm))
for key_id, key_json in txn.fetchall(): for key_id, key_json in txn:
device_result[algorithm + ":" + key_id] = key_json device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id)) delete.append((user_id, device_id, algorithm, key_id))
sql = ( sql = (

View file

@ -74,7 +74,7 @@ class EventFederationStore(SQLBaseStore):
base_sql % (",".join(["?"] * len(chunk)),), base_sql % (",".join(["?"] * len(chunk)),),
chunk chunk
) )
new_front.update([r[0] for r in txn.fetchall()]) new_front.update([r[0] for r in txn])
new_front -= results new_front -= results
@ -110,7 +110,7 @@ class EventFederationStore(SQLBaseStore):
txn.execute(sql, (room_id, False,)) txn.execute(sql, (room_id, False,))
return dict(txn.fetchall()) return dict(txn)
def _get_oldest_events_in_room_txn(self, txn, room_id): def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn( return self._simple_select_onecol_txn(
@ -152,7 +152,7 @@ class EventFederationStore(SQLBaseStore):
txn.execute(sql, (room_id, )) txn.execute(sql, (room_id, ))
results = [] results = []
for event_id, depth in txn.fetchall(): for event_id, depth in txn:
hashes = self._get_event_reference_hashes_txn(txn, event_id) hashes = self._get_event_reference_hashes_txn(txn, event_id)
prev_hashes = { prev_hashes = {
k: encode_base64(v) for k, v in hashes.items() k: encode_base64(v) for k, v in hashes.items()
@ -334,8 +334,7 @@ class EventFederationStore(SQLBaseStore):
def get_forward_extremeties_for_room_txn(txn): def get_forward_extremeties_for_room_txn(txn):
txn.execute(sql, (stream_ordering, room_id)) txn.execute(sql, (stream_ordering, room_id))
rows = txn.fetchall() return [event_id for event_id, in txn]
return [event_id for event_id, in rows]
return self.runInteraction( return self.runInteraction(
"get_forward_extremeties_for_room", "get_forward_extremeties_for_room",
@ -436,7 +435,7 @@ class EventFederationStore(SQLBaseStore):
(room_id, event_id, False, limit - len(event_results)) (room_id, event_id, False, limit - len(event_results))
) )
for row in txn.fetchall(): for row in txn:
if row[1] not in event_results: if row[1] not in event_results:
queue.put((-row[0], row[1])) queue.put((-row[0], row[1]))
@ -482,7 +481,7 @@ class EventFederationStore(SQLBaseStore):
(room_id, event_id, False, limit - len(event_results)) (room_id, event_id, False, limit - len(event_results))
) )
for e_id, in txn.fetchall(): for e_id, in txn:
new_front.add(e_id) new_front.add(e_id)
new_front -= earliest_events new_front -= earliest_events

View file

@ -208,7 +208,7 @@ class EventPushActionsStore(SQLBaseStore):
" stream_ordering >= ? AND stream_ordering <= ?" " stream_ordering >= ? AND stream_ordering <= ?"
) )
txn.execute(sql, (min_stream_ordering, max_stream_ordering)) txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn.fetchall()] return [r[0] for r in txn]
ret = yield self.runInteraction("get_push_action_users_in_range", f) ret = yield self.runInteraction("get_push_action_users_in_range", f)
defer.returnValue(ret) defer.returnValue(ret)

View file

@ -217,14 +217,14 @@ class EventsStore(SQLBaseStore):
partitioned.setdefault(event.room_id, []).append((event, ctx)) partitioned.setdefault(event.room_id, []).append((event, ctx))
deferreds = [] deferreds = []
for room_id, evs_ctxs in partitioned.items(): for room_id, evs_ctxs in partitioned.iteritems():
d = preserve_fn(self._event_persist_queue.add_to_queue)( d = preserve_fn(self._event_persist_queue.add_to_queue)(
room_id, evs_ctxs, room_id, evs_ctxs,
backfilled=backfilled, backfilled=backfilled,
) )
deferreds.append(d) deferreds.append(d)
for room_id in partitioned.keys(): for room_id in partitioned:
self._maybe_start_persisting(room_id) self._maybe_start_persisting(room_id)
return preserve_context_over_deferred( return preserve_context_over_deferred(
@ -323,7 +323,7 @@ class EventsStore(SQLBaseStore):
(event, context) (event, context)
) )
for room_id, ev_ctx_rm in events_by_room.items(): for room_id, ev_ctx_rm in events_by_room.iteritems():
# Work out new extremities by recursively adding and removing # Work out new extremities by recursively adding and removing
# the new events. # the new events.
latest_event_ids = yield self.get_latest_event_ids_in_room( latest_event_ids = yield self.get_latest_event_ids_in_room(
@ -453,10 +453,10 @@ class EventsStore(SQLBaseStore):
missing_event_ids, missing_event_ids,
) )
groups = set(event_to_groups.values()) groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups) group_to_state = yield self._get_state_for_groups(groups)
state_sets.extend(group_to_state.values()) state_sets.extend(group_to_state.itervalues())
if not new_latest_event_ids: if not new_latest_event_ids:
current_state = {} current_state = {}
@ -718,7 +718,7 @@ class EventsStore(SQLBaseStore):
def _update_forward_extremities_txn(self, txn, new_forward_extremities, def _update_forward_extremities_txn(self, txn, new_forward_extremities,
max_stream_order): max_stream_order):
for room_id, new_extrem in new_forward_extremities.items(): for room_id, new_extrem in new_forward_extremities.iteritems():
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="event_forward_extremities", table="event_forward_extremities",
@ -736,7 +736,7 @@ class EventsStore(SQLBaseStore):
"event_id": ev_id, "event_id": ev_id,
"room_id": room_id, "room_id": room_id,
} }
for room_id, new_extrem in new_forward_extremities.items() for room_id, new_extrem in new_forward_extremities.iteritems()
for ev_id in new_extrem for ev_id in new_extrem
], ],
) )
@ -753,7 +753,7 @@ class EventsStore(SQLBaseStore):
"event_id": event_id, "event_id": event_id,
"stream_ordering": max_stream_order, "stream_ordering": max_stream_order,
} }
for room_id, new_extrem in new_forward_extremities.items() for room_id, new_extrem in new_forward_extremities.iteritems()
for event_id in new_extrem for event_id in new_extrem
] ]
) )
@ -807,7 +807,7 @@ class EventsStore(SQLBaseStore):
event.depth, depth_updates.get(event.room_id, event.depth) event.depth, depth_updates.get(event.room_id, event.depth)
) )
for room_id, depth in depth_updates.items(): for room_id, depth in depth_updates.iteritems():
self._update_min_depth_for_room_txn(txn, room_id, depth) self._update_min_depth_for_room_txn(txn, room_id, depth)
def _update_outliers_txn(self, txn, events_and_contexts): def _update_outliers_txn(self, txn, events_and_contexts):
@ -834,7 +834,7 @@ class EventsStore(SQLBaseStore):
have_persisted = { have_persisted = {
event_id: outlier event_id: outlier
for event_id, outlier in txn.fetchall() for event_id, outlier in txn
} }
to_remove = set() to_remove = set()
@ -957,14 +957,10 @@ class EventsStore(SQLBaseStore):
return return
def event_dict(event): def event_dict(event):
return { d = event.get_dict()
k: v d.pop("redacted", None)
for k, v in event.get_dict().items() d.pop("redacted_because", None)
if k not in [ return d
"redacted",
"redacted_because",
]
}
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
@ -1997,7 +1993,7 @@ class EventsStore(SQLBaseStore):
"state_key": key[1], "state_key": key[1],
"event_id": state_id, "event_id": state_id,
} }
for key, state_id in curr_state.items() for key, state_id in curr_state.iteritems()
], ],
) )

View file

@ -101,9 +101,10 @@ class KeyStore(SQLBaseStore):
key_ids key_ids
Args: Args:
server_name (str): The name of the server. server_name (str): The name of the server.
key_ids (list of str): List of key_ids to try and look up. key_ids (iterable[str]): key_ids to try and look up.
Returns: Returns:
(list of VerifyKey): The verification keys. Deferred: resolves to dict[str, VerifyKey]: map from
key_id to verification key.
""" """
keys = {} keys = {}
for key_id in key_ids: for key_id in key_ids:

View file

@ -356,7 +356,7 @@ def _get_or_create_schema_state(txn, database_engine):
), ),
(current_version,) (current_version,)
) )
applied_deltas = [d for d, in txn.fetchall()] applied_deltas = [d for d, in txn]
return current_version, applied_deltas, upgraded return current_version, applied_deltas, upgraded
return None return None

View file

@ -85,8 +85,8 @@ class PresenceStore(SQLBaseStore):
self.presence_stream_cache.entity_has_changed, self.presence_stream_cache.entity_has_changed,
state.user_id, stream_id, state.user_id, stream_id,
) )
self._invalidate_cache_and_stream( txn.call_after(
txn, self._get_presence_for_user, (state.user_id,) self._get_presence_for_user.invalidate, (state.user_id,)
) )
# Actually insert new rows # Actually insert new rows

View file

@ -313,10 +313,9 @@ class ReceiptsStore(SQLBaseStore):
) )
txn.execute(sql, (room_id, receipt_type, user_id)) txn.execute(sql, (room_id, receipt_type, user_id))
results = txn.fetchall()
if results and topological_ordering: if topological_ordering:
for to, so, _ in results: for to, so, _ in txn:
if int(to) > topological_ordering: if int(to) > topological_ordering:
return False return False
elif int(to) == topological_ordering and int(so) >= stream_ordering: elif int(to) == topological_ordering and int(so) >= stream_ordering:

View file

@ -209,7 +209,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
" WHERE lower(name) = lower(?)" " WHERE lower(name) = lower(?)"
) )
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
return dict(txn.fetchall()) return dict(txn)
return self.runInteraction("get_users_by_id_case_insensitive", f) return self.runInteraction("get_users_by_id_case_insensitive", f)

View file

@ -396,7 +396,7 @@ class RoomStore(SQLBaseStore):
sql % ("AND appservice_id IS NULL",), sql % ("AND appservice_id IS NULL",),
(stream_id,) (stream_id,)
) )
return dict(txn.fetchall()) return dict(txn)
else: else:
# We want to get from all lists, so we need to aggregate the results # We want to get from all lists, so we need to aggregate the results
@ -422,7 +422,7 @@ class RoomStore(SQLBaseStore):
results = {} results = {}
# A room is visible if its visible on any list. # A room is visible if its visible on any list.
for room_id, visibility in txn.fetchall(): for room_id, visibility in txn:
results[room_id] = bool(visibility) or results.get(room_id, False) results[room_id] = bool(visibility) or results.get(room_id, False)
return results return results

View file

@ -129,17 +129,30 @@ class RoomMemberStore(SQLBaseStore):
with self._stream_id_gen.get_next() as stream_ordering: with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering) yield self.runInteraction("locally_reject_invite", f, stream_ordering)
@cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
def get_hosts_in_room(self, room_id, cache_context):
"""Returns the set of all hosts currently in the room
"""
user_ids = yield self.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate,
)
hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
defer.returnValue(hosts)
@cached(max_entries=500000, iterable=True) @cached(max_entries=500000, iterable=True)
def get_users_in_room(self, room_id): def get_users_in_room(self, room_id):
def f(txn): def f(txn):
sql = (
rows = self._get_members_rows_txn( "SELECT m.user_id FROM room_memberships as m"
txn, " INNER JOIN current_state_events as c"
room_id=room_id, " ON m.event_id = c.event_id "
membership=Membership.JOIN, " AND m.room_id = c.room_id "
" AND m.user_id = c.state_key"
" WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
) )
return [r["user_id"] for r in rows] txn.execute(sql, (room_id, Membership.JOIN,))
return [r[0] for r in txn]
return self.runInteraction("get_users_in_room", f) return self.runInteraction("get_users_in_room", f)
@cached() @cached()
@ -246,34 +259,6 @@ class RoomMemberStore(SQLBaseStore):
return results return results
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
where_clause = "c.room_id = ?"
where_values = [room_id]
if membership:
where_clause += " AND m.membership = ?"
where_values.append(membership)
if user_id:
where_clause += " AND m.user_id = ?"
where_values.append(user_id)
sql = (
"SELECT m.* FROM room_memberships as m"
" INNER JOIN current_state_events as c"
" ON m.event_id = c.event_id "
" AND m.room_id = c.room_id "
" AND m.user_id = c.state_key"
" WHERE c.type = 'm.room.member' AND %(where)s"
) % {
"where": where_clause,
}
txn.execute(sql, where_values)
rows = self.cursor_to_dict(txn)
return rows
@cachedInlineCallbacks(max_entries=500000, iterable=True) @cachedInlineCallbacks(max_entries=500000, iterable=True)
def get_rooms_for_user(self, user_id): def get_rooms_for_user(self, user_id):
"""Returns a set of room_ids the user is currently joined to """Returns a set of room_ids the user is currently joined to

View file

@ -72,7 +72,7 @@ class SignatureStore(SQLBaseStore):
" WHERE event_id = ?" " WHERE event_id = ?"
) )
txn.execute(query, (event_id, )) txn.execute(query, (event_id, ))
return {k: v for k, v in txn.fetchall()} return {k: v for k, v in txn}
def _store_event_reference_hashes_txn(self, txn, events): def _store_event_reference_hashes_txn(self, txn, events):
"""Store a hash for a PDU """Store a hash for a PDU

View file

@ -90,7 +90,7 @@ class StateStore(SQLBaseStore):
event_ids, event_ids,
) )
groups = set(event_to_groups.values()) groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups) group_to_state = yield self._get_state_for_groups(groups)
defer.returnValue(group_to_state) defer.returnValue(group_to_state)
@ -108,17 +108,18 @@ class StateStore(SQLBaseStore):
state_event_map = yield self.get_events( state_event_map = yield self.get_events(
[ [
ev_id for group_ids in group_to_ids.values() ev_id for group_ids in group_to_ids.itervalues()
for ev_id in group_ids.values() for ev_id in group_ids.itervalues()
], ],
get_prev_content=False get_prev_content=False
) )
defer.returnValue({ defer.returnValue({
group: [ group: [
state_event_map[v] for v in event_id_map.values() if v in state_event_map state_event_map[v] for v in event_id_map.itervalues()
if v in state_event_map
] ]
for group, event_id_map in group_to_ids.items() for group, event_id_map in group_to_ids.iteritems()
}) })
def _have_persisted_state_group_txn(self, txn, state_group): def _have_persisted_state_group_txn(self, txn, state_group):
@ -190,7 +191,7 @@ class StateStore(SQLBaseStore):
"state_key": key[1], "state_key": key[1],
"event_id": state_id, "event_id": state_id,
} }
for key, state_id in context.delta_ids.items() for key, state_id in context.delta_ids.iteritems()
], ],
) )
else: else:
@ -205,7 +206,7 @@ class StateStore(SQLBaseStore):
"state_key": key[1], "state_key": key[1],
"event_id": state_id, "event_id": state_id,
} }
for key, state_id in context.current_state_ids.items() for key, state_id in context.current_state_ids.iteritems()
], ],
) )
@ -217,7 +218,7 @@ class StateStore(SQLBaseStore):
"state_group": state_group_id, "state_group": state_group_id,
"event_id": event_id, "event_id": event_id,
} }
for event_id, state_group_id in state_groups.items() for event_id, state_group_id in state_groups.iteritems()
], ],
) )
@ -341,10 +342,10 @@ class StateStore(SQLBaseStore):
args.extend(where_args) args.extend(where_args)
txn.execute(sql % (where_clause,), args) txn.execute(sql % (where_clause,), args)
rows = self.cursor_to_dict(txn) for row in txn:
for row in rows: typ, state_key, event_id = row
key = (row["type"], row["state_key"]) key = (typ, state_key)
results[group][key] = row["event_id"] results[group][key] = event_id
else: else:
if types is not None: if types is not None:
where_clause = "AND (%s)" % ( where_clause = "AND (%s)" % (
@ -373,12 +374,11 @@ class StateStore(SQLBaseStore):
" WHERE state_group = ? %s" % (where_clause,), " WHERE state_group = ? %s" % (where_clause,),
args args
) )
rows = txn.fetchall() results[group].update(
results[group].update({ ((typ, state_key), event_id)
(typ, state_key): event_id for typ, state_key, event_id in txn
for typ, state_key, event_id in rows
if (typ, state_key) not in results[group] if (typ, state_key) not in results[group]
}) )
# If the lengths match then we must have all the types, # If the lengths match then we must have all the types,
# so no need to go walk further down the tree. # so no need to go walk further down the tree.
@ -415,21 +415,21 @@ class StateStore(SQLBaseStore):
event_ids, event_ids,
) )
groups = set(event_to_groups.values()) groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups, types) group_to_state = yield self._get_state_for_groups(groups, types)
state_event_map = yield self.get_events( state_event_map = yield self.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()], [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()],
get_prev_content=False get_prev_content=False
) )
event_to_state = { event_to_state = {
event_id: { event_id: {
k: state_event_map[v] k: state_event_map[v]
for k, v in group_to_state[group].items() for k, v in group_to_state[group].iteritems()
if v in state_event_map if v in state_event_map
} }
for event_id, group in event_to_groups.items() for event_id, group in event_to_groups.iteritems()
} }
defer.returnValue({event: event_to_state[event] for event in event_ids}) defer.returnValue({event: event_to_state[event] for event in event_ids})
@ -452,12 +452,12 @@ class StateStore(SQLBaseStore):
event_ids, event_ids,
) )
groups = set(event_to_groups.values()) groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups, types) group_to_state = yield self._get_state_for_groups(groups, types)
event_to_state = { event_to_state = {
event_id: group_to_state[group] event_id: group_to_state[group]
for event_id, group in event_to_groups.items() for event_id, group in event_to_groups.iteritems()
} }
defer.returnValue({event: event_to_state[event] for event in event_ids}) defer.returnValue({event: event_to_state[event] for event in event_ids})
@ -569,7 +569,7 @@ class StateStore(SQLBaseStore):
got_all = not (missing_types or types is None) got_all = not (missing_types or types is None)
return { return {
k: v for k, v in state_dict_ids.items() k: v for k, v in state_dict_ids.iteritems()
if include(k[0], k[1]) if include(k[0], k[1])
}, missing_types, got_all }, missing_types, got_all
@ -628,7 +628,7 @@ class StateStore(SQLBaseStore):
# Now we want to update the cache with all the things we fetched # Now we want to update the cache with all the things we fetched
# from the database. # from the database.
for group, group_state_dict in group_to_state_dict.items(): for group, group_state_dict in group_to_state_dict.iteritems():
if types: if types:
# We delibrately put key -> None mappings into the cache to # We delibrately put key -> None mappings into the cache to
# cache absence of the key, on the assumption that if we've # cache absence of the key, on the assumption that if we've
@ -643,10 +643,10 @@ class StateStore(SQLBaseStore):
else: else:
state_dict = results[group] state_dict = results[group]
state_dict.update({ state_dict.update(
(intern_string(k[0]), intern_string(k[1])): v ((intern_string(k[0]), intern_string(k[1])), v)
for k, v in group_state_dict.items() for k, v in group_state_dict.iteritems()
}) )
self._state_group_cache.update( self._state_group_cache.update(
cache_seq_num, cache_seq_num,
@ -657,10 +657,10 @@ class StateStore(SQLBaseStore):
# Remove all the entries with None values. The None values were just # Remove all the entries with None values. The None values were just
# used for bookkeeping in the cache. # used for bookkeeping in the cache.
for group, state_dict in results.items(): for group, state_dict in results.iteritems():
results[group] = { results[group] = {
key: event_id key: event_id
for key, event_id in state_dict.items() for key, event_id in state_dict.iteritems()
if event_id if event_id
} }
@ -749,7 +749,7 @@ class StateStore(SQLBaseStore):
# of keys # of keys
delta_state = { delta_state = {
key: value for key, value in curr_state.items() key: value for key, value in curr_state.iteritems()
if prev_state.get(key, None) != value if prev_state.get(key, None) != value
} }
@ -789,7 +789,7 @@ class StateStore(SQLBaseStore):
"state_key": key[1], "state_key": key[1],
"event_id": state_id, "event_id": state_id,
} }
for key, state_id in delta_state.items() for key, state_id in delta_state.iteritems()
], ],
) )

View file

@ -95,7 +95,7 @@ class TagsStore(SQLBaseStore):
for stream_id, user_id, room_id in tag_ids: for stream_id, user_id, room_id in tag_ids:
txn.execute(sql, (user_id, room_id)) txn.execute(sql, (user_id, room_id))
tags = [] tags = []
for tag, content in txn.fetchall(): for tag, content in txn:
tags.append(json.dumps(tag) + ":" + content) tags.append(json.dumps(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}" tag_json = "{" + ",".join(tags) + "}"
results.append((stream_id, user_id, room_id, tag_json)) results.append((stream_id, user_id, room_id, tag_json))
@ -132,7 +132,7 @@ class TagsStore(SQLBaseStore):
" WHERE user_id = ? AND stream_id > ?" " WHERE user_id = ? AND stream_id > ?"
) )
txn.execute(sql, (user_id, stream_id)) txn.execute(sql, (user_id, stream_id))
room_ids = [row[0] for row in txn.fetchall()] room_ids = [row[0] for row in txn]
return room_ids return room_ids
changed = self._account_data_stream_cache.has_entity_changed( changed = self._account_data_stream_cache.has_entity_changed(

View file

@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
class DeferredTimedOutError(SynapseError): class DeferredTimedOutError(SynapseError):
def __init__(self): def __init__(self):
super(SynapseError).__init__(504, "Timed out") super(SynapseError, self).__init__(504, "Timed out")
def unwrapFirstError(failure): def unwrapFirstError(failure):
@ -93,8 +93,10 @@ class Clock(object):
ret_deferred = defer.Deferred() ret_deferred = defer.Deferred()
def timed_out_fn(): def timed_out_fn():
e = DeferredTimedOutError()
try: try:
ret_deferred.errback(DeferredTimedOutError()) ret_deferred.errback(e)
except: except:
pass pass
@ -114,7 +116,7 @@ class Clock(object):
ret_deferred.addBoth(cancel) ret_deferred.addBoth(cancel)
def sucess(res): def success(res):
try: try:
ret_deferred.callback(res) ret_deferred.callback(res)
except: except:
@ -128,7 +130,7 @@ class Clock(object):
except: except:
pass pass
given_deferred.addCallbacks(callback=sucess, errback=err) given_deferred.addCallbacks(callback=success, errback=err)
timer = self.call_later(time_out, timed_out_fn) timer = self.call_later(time_out, timed_out_fn)

View file

@ -189,7 +189,55 @@ class Cache(object):
self.cache.clear() self.cache.clear()
class CacheDescriptor(object): class _CacheDescriptorBase(object):
def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
arg_spec = inspect.getargspec(orig)
all_args = arg_spec.args
if "cache_context" in all_args:
if not cache_context:
raise ValueError(
"Cannot have a 'cache_context' arg without setting"
" cache_context=True"
)
elif cache_context:
raise ValueError(
"Cannot have cache_context=True without having an arg"
" named `cache_context`"
)
if num_args is None:
num_args = len(all_args) - 1
if cache_context:
num_args -= 1
if len(all_args) < num_args + 1:
raise Exception(
"Not enough explicit positional arguments to key off for %r: "
"got %i args, but wanted %i. (@cached cannot key off *args or "
"**kwargs)"
% (orig.__name__, len(all_args), num_args)
)
self.num_args = num_args
self.arg_names = all_args[1:num_args + 1]
if "cache_context" in self.arg_names:
raise Exception(
"cache_context arg cannot be included among the cache keys"
)
self.add_cache_context = cache_context
class CacheDescriptor(_CacheDescriptorBase):
""" A method decorator that applies a memoizing cache around the function. """ A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that This caches deferreds, rather than the results themselves. Deferreds that
@ -217,52 +265,24 @@ class CacheDescriptor(object):
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
defer.returnValue(r1 + r2) defer.returnValue(r1 + r2)
Args:
num_args (int): number of positional arguments (excluding ``self`` and
``cache_context``) to use as cache keys. Defaults to all named
args of the function.
""" """
def __init__(self, orig, max_entries=1000, num_args=1, tree=False, def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
inlineCallbacks=False, cache_context=False, iterable=False): inlineCallbacks=False, cache_context=False, iterable=False):
super(CacheDescriptor, self).__init__(
orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
cache_context=cache_context)
max_entries = int(max_entries * CACHE_SIZE_FACTOR) max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.max_entries = max_entries self.max_entries = max_entries
self.num_args = num_args
self.tree = tree self.tree = tree
self.iterable = iterable self.iterable = iterable
all_args = inspect.getargspec(orig)
self.arg_names = all_args.args[1:num_args + 1]
if "cache_context" in all_args.args:
if not cache_context:
raise ValueError(
"Cannot have a 'cache_context' arg without setting"
" cache_context=True"
)
try:
self.arg_names.remove("cache_context")
except ValueError:
pass
elif cache_context:
raise ValueError(
"Cannot have cache_context=True without having an arg"
" named `cache_context`"
)
self.add_cache_context = cache_context
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwargs)"
% (orig.__name__,)
)
def __get__(self, obj, objtype=None): def __get__(self, obj, objtype=None):
cache = Cache( cache = Cache(
name=self.orig.__name__, name=self.orig.__name__,
@ -338,48 +358,36 @@ class CacheDescriptor(object):
return wrapped return wrapped
class CacheListDescriptor(object): class CacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys. """Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes Given a list of keys it looks in the cache to find any hits, then passes
the list of missing keys to the wrapped fucntion. the list of missing keys to the wrapped fucntion.
""" """
def __init__(self, orig, cached_method_name, list_name, num_args=1, def __init__(self, orig, cached_method_name, list_name, num_args=None,
inlineCallbacks=False): inlineCallbacks=False):
""" """
Args: Args:
orig (function) orig (function)
method_name (str); The name of the chached method. cached_method_name (str): The name of the chached method.
list_name (str): Name of the argument which is the bulk lookup list list_name (str): Name of the argument which is the bulk lookup list
num_args (int) num_args (int): number of positional arguments (excluding ``self``,
but including list_name) to use as cache keys. Defaults to all
named args of the function.
inlineCallbacks (bool): Whether orig is a generator that should inlineCallbacks (bool): Whether orig is a generator that should
be wrapped by defer.inlineCallbacks be wrapped by defer.inlineCallbacks
""" """
self.orig = orig super(CacheListDescriptor, self).__init__(
orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.num_args = num_args
self.list_name = list_name self.list_name = list_name
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name) self.list_pos = self.arg_names.index(self.list_name)
self.cached_method_name = cached_method_name self.cached_method_name = cached_method_name
self.sentinel = object() self.sentinel = object()
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)
if self.list_name not in self.arg_names: if self.list_name not in self.arg_names:
raise Exception( raise Exception(
"Couldn't see arguments %r for %r." "Couldn't see arguments %r for %r."
@ -487,7 +495,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
self.cache.invalidate(self.key) self.cache.invalidate(self.key)
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
iterable=False): iterable=False):
return lambda orig: CacheDescriptor( return lambda orig: CacheDescriptor(
orig, orig,
@ -499,8 +507,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
) )
def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False, def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
iterable=False): cache_context=False, iterable=False):
return lambda orig: CacheDescriptor( return lambda orig: CacheDescriptor(
orig, orig,
max_entries=max_entries, max_entries=max_entries,
@ -512,7 +520,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
) )
def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False): def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`. """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument Used to do batch lookups for an already created cache. A single argument
@ -525,7 +533,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False)
cache (Cache): The underlying cache to use. cache (Cache): The underlying cache to use.
list_name (str): The name of the argument that is the list to use to list_name (str): The name of the argument that is the list to use to
do batch lookups in the cache. do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache. num_args (int): Number of arguments to use as the key in the cache
(including list_name). Defaults to all named parameters.
inlineCallbacks (bool): Should the function be wrapped in an inlineCallbacks (bool): Should the function be wrapped in an
`defer.inlineCallbacks`? `defer.inlineCallbacks`?

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import synapse.util.logcontext
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
@ -35,7 +35,8 @@ class NotRetryingDestination(Exception):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_retry_limiter(destination, clock, store, **kwargs): def get_retry_limiter(destination, clock, store, ignore_backoff=False,
**kwargs):
"""For a given destination check if we have previously failed to """For a given destination check if we have previously failed to
send a request there and are waiting before retrying the destination. send a request there and are waiting before retrying the destination.
If we are not ready to retry the destination, this will raise a If we are not ready to retry the destination, this will raise a
@ -43,6 +44,14 @@ def get_retry_limiter(destination, clock, store, **kwargs):
that will mark the destination as down if an exception is thrown (excluding that will mark the destination as down if an exception is thrown (excluding
CodeMessageException with code < 500) CodeMessageException with code < 500)
Args:
destination (str): name of homeserver
clock (synapse.util.clock): timing source
store (synapse.storage.transactions.TransactionStore): datastore
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway. We will still update the next
retry_interval on success/failure.
Example usage: Example usage:
try: try:
@ -66,7 +75,7 @@ def get_retry_limiter(destination, clock, store, **kwargs):
now = int(clock.time_msec()) now = int(clock.time_msec())
if retry_last_ts + retry_interval > now: if not ignore_backoff and retry_last_ts + retry_interval > now:
raise NotRetryingDestination( raise NotRetryingDestination(
retry_last_ts=retry_last_ts, retry_last_ts=retry_last_ts,
retry_interval=retry_interval, retry_interval=retry_interval,
@ -124,7 +133,13 @@ class RetryDestinationLimiter(object):
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
valid_err_code = False valid_err_code = False
if exc_type is not None and issubclass(exc_type, CodeMessageException): if exc_type is None:
valid_err_code = True
elif not issubclass(exc_type, Exception):
# avoid treating exceptions which don't derive from Exception as
# failures; this is mostly so as not to catch defer._DefGen.
valid_err_code = True
elif issubclass(exc_type, CodeMessageException):
# Some error codes are perfectly fine for some APIs, whereas other # Some error codes are perfectly fine for some APIs, whereas other
# APIs may expect to never received e.g. a 404. It's important to # APIs may expect to never received e.g. a 404. It's important to
# handle 404 as some remote servers will return a 404 when the HS # handle 404 as some remote servers will return a 404 when the HS
@ -142,11 +157,13 @@ class RetryDestinationLimiter(object):
else: else:
valid_err_code = False valid_err_code = False
if exc_type is None or valid_err_code: if valid_err_code:
# We connected successfully. # We connected successfully.
if not self.retry_interval: if not self.retry_interval:
return return
logger.debug("Connection to %s was successful; clearing backoff",
self.destination)
retry_last_ts = 0 retry_last_ts = 0
self.retry_interval = 0 self.retry_interval = 0
else: else:
@ -160,6 +177,10 @@ class RetryDestinationLimiter(object):
else: else:
self.retry_interval = self.min_retry_interval self.retry_interval = self.min_retry_interval
logger.debug(
"Connection to %s was unsuccessful (%s(%s)); backoff now %i",
self.destination, exc_type, exc_val, self.retry_interval
)
retry_last_ts = int(self.clock.time_msec()) retry_last_ts = int(self.clock.time_msec())
@defer.inlineCallbacks @defer.inlineCallbacks
@ -173,4 +194,5 @@ class RetryDestinationLimiter(object):
"Failed to store set_destination_retry_timings", "Failed to store set_destination_retry_timings",
) )
store_retry_timings() # we deliberately do this in the background.
synapse.util.logcontext.preserve_fn(store_retry_timings)()

View file

@ -134,6 +134,13 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
if prev_membership not in MEMBERSHIP_PRIORITY: if prev_membership not in MEMBERSHIP_PRIORITY:
prev_membership = "leave" prev_membership = "leave"
# Always allow the user to see their own leave events, otherwise
# they won't see the room disappear if they reject the invite
if membership == "leave" and (
prev_membership == "join" or prev_membership == "invite"
):
return True
new_priority = MEMBERSHIP_PRIORITY.index(membership) new_priority = MEMBERSHIP_PRIORITY.index(membership)
old_priority = MEMBERSHIP_PRIORITY.index(prev_membership) old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
if old_priority < new_priority: if old_priority < new_priority:

View file

@ -23,6 +23,9 @@ from tests.utils import (
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.api.errors import SynapseError
import jsonschema
user_localpart = "test_user" user_localpart = "test_user"
@ -54,6 +57,70 @@ class FilteringTestCase(unittest.TestCase):
self.datastore = hs.get_datastore() self.datastore = hs.get_datastore()
def test_errors_on_invalid_filters(self):
invalid_filters = [
{"boom": {}},
{"account_data": "Hello World"},
{"event_fields": ["\\foo"]},
{"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}},
{"event_format": "other"},
{"room": {"not_rooms": ["#foo:pik-test"]}},
{"presence": {"senders": ["@bar;pik.test.com"]}}
]
for filter in invalid_filters:
with self.assertRaises(SynapseError) as check_filter_error:
self.filtering.check_valid_filter(filter)
self.assertIsInstance(check_filter_error.exception, SynapseError)
def test_valid_filters(self):
valid_filters = [
{
"room": {
"timeline": {"limit": 20},
"state": {"not_types": ["m.room.member"]},
"ephemeral": {"limit": 0, "not_types": ["*"]},
"include_leave": False,
"rooms": ["!dee:pik-test"],
"not_rooms": ["!gee:pik-test"],
"account_data": {"limit": 0, "types": ["*"]}
}
},
{
"room": {
"state": {
"types": ["m.room.*"],
"not_rooms": ["!726s6s6q:example.com"]
},
"timeline": {
"limit": 10,
"types": ["m.room.message"],
"not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"]
},
"ephemeral": {
"types": ["m.receipt", "m.typing"],
"not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"]
}
},
"presence": {
"types": ["m.presence"],
"not_senders": ["@alice:example.com"]
},
"event_format": "client",
"event_fields": ["type", "content", "sender"]
}
]
for filter in valid_filters:
try:
self.filtering.check_valid_filter(filter)
except jsonschema.ValidationError as e:
self.fail(e)
def test_limits_are_applied(self):
# TODO
pass
def test_definition_types_works_with_literals(self): def test_definition_types_works_with_literals(self):
definition = { definition = {
"types": ["m.room.message", "org.matrix.foo.bar"] "types": ["m.room.message", "org.matrix.foo.bar"]

View file

@ -93,6 +93,7 @@ class DirectoryTestCase(unittest.TestCase):
"room_alias": "#another:remote", "room_alias": "#another:remote",
}, },
retry_on_dns_fail=False, retry_on_dns_fail=False,
ignore_backoff=True,
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -119,7 +119,8 @@ class ProfileTestCase(unittest.TestCase):
self.mock_federation.make_query.assert_called_with( self.mock_federation.make_query.assert_called_with(
destination="remote", destination="remote",
query_type="profile", query_type="profile",
args={"user_id": "@alice:remote", "field": "displayname"} args={"user_id": "@alice:remote", "field": "displayname"},
ignore_backoff=True,
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -192,6 +192,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True, long_retries=True,
backoff_on_404=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -263,6 +264,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True, long_retries=True,
backoff_on_404=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )

View file

@ -33,8 +33,8 @@ PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.TestCase): class FilterTestCase(unittest.TestCase):
USER_ID = "@apple:test" USER_ID = "@apple:test"
EXAMPLE_FILTER = {"type": ["m.*"]} EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}' EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}'
TO_REGISTER = [filter] TO_REGISTER = [filter]
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -89,7 +89,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_select_one_1col(self): def test_select_one_1col(self):
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
self.mock_txn.fetchall.return_value = [("Value",)] self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
value = yield self.datastore._simple_select_one_onecol( value = yield self.datastore._simple_select_one_onecol(
table="tablename", table="tablename",
@ -136,7 +136,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_select_list(self): def test_select_list(self):
self.mock_txn.rowcount = 3 self.mock_txn.rowcount = 3
self.mock_txn.fetchall.return_value = ((1,), (2,), (3,)) self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = ( self.mock_txn.description = (
("colA", None, None, None, None, None, None), ("colA", None, None, None, None, None, None),
) )

View file

@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import signedjson.key
from twisted.internet import defer
import tests.unittest
import tests.utils
class KeyStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(KeyStoreTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.keys.KeyStore
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver()
self.store = hs.get_datastore()
@defer.inlineCallbacks
def test_get_server_verify_keys(self):
key1 = signedjson.key.decode_verify_key_base64(
"ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
)
key2 = signedjson.key.decode_verify_key_base64(
"ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
)
yield self.store.store_server_verify_key(
"server1", "from_server", 0, key1
)
yield self.store.store_server_verify_key(
"server1", "from_server", 0, key2
)
res = yield self.store.get_server_verify_keys(
"server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"])
self.assertEqual(len(res.keys()), 2)
self.assertEqual(res["ed25519:key1"].version, "key1")
self.assertEqual(res["ed25519:key2"].version, "key2")

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mock
from twisted.internet import defer
from synapse.util.caches import descriptors
from tests import unittest
class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):
class Cls(object):
def __init__(self):
self.mock = mock.Mock()
@descriptors.cached()
def fn(self, arg1, arg2):
return self.mock(arg1, arg2)
obj = Cls()
obj.mock.return_value = 'fish'
r = yield obj.fn(1, 2)
self.assertEqual(r, 'fish')
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = 'chips'
r = yield obj.fn(1, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_called_once_with(1, 3)
obj.mock.reset_mock()
# the two values should now be cached
r = yield obj.fn(1, 2)
self.assertEqual(r, 'fish')
r = yield obj.fn(1, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_not_called()
@defer.inlineCallbacks
def test_cache_num_args(self):
"""Only the first num_args arguments should matter to the cache"""
class Cls(object):
def __init__(self):
self.mock = mock.Mock()
@descriptors.cached(num_args=1)
def fn(self, arg1, arg2):
return self.mock(arg1, arg2)
obj = Cls()
obj.mock.return_value = 'fish'
r = yield obj.fn(1, 2)
self.assertEqual(r, 'fish')
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = 'chips'
r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_called_once_with(2, 3)
obj.mock.reset_mock()
# the two values should now be cached; we should be able to vary
# the second argument and still get the cached result.
r = yield obj.fn(1, 4)
self.assertEqual(r, 'fish')
r = yield obj.fn(2, 5)
self.assertEqual(r, 'chips')
obj.mock.assert_not_called()

33
tests/util/test_clock.py Normal file
View file

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse import util
from twisted.internet import defer
from tests import unittest
class ClockTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_time_bound_deferred(self):
# just a deferred which never resolves
slow_deferred = defer.Deferred()
clock = util.Clock()
time_bound = clock.time_bound_deferred(slow_deferred, 0.001)
try:
yield time_bound
self.fail("Expected timedout error, but got nothing")
except util.DeferredTimedOutError:
pass