Merge branch 'develop' into application-services

This commit is contained in:
Kegan Dougal 2015-02-05 14:28:03 +00:00
commit 951690e54d
27 changed files with 738 additions and 423 deletions

View file

@ -1,3 +1,8 @@
Changes in develop
==================
* pydenticon support -- adds dep on pydenticon
Changes in synapse 0.6.1 (2015-01-07) Changes in synapse 0.6.1 (2015-01-07)
===================================== =====================================

65
scripts/check_auth.py Normal file
View file

@ -0,0 +1,65 @@
from synapse.events import FrozenEvent
from synapse.api.auth import Auth
from mock import Mock
import argparse
import itertools
import json
import sys
def check_auth(auth, auth_chain, events):
auth_chain.sort(key=lambda e: e.depth)
auth_map = {
e.event_id: e
for e in auth_chain
}
create_events = {}
for e in auth_chain:
if e.type == "m.room.create":
create_events[e.room_id] = e
for e in itertools.chain(auth_chain, events):
auth_events_list = [auth_map[i] for i, _ in e.auth_events]
auth_events = {
(e.type, e.state_key): e
for e in auth_events_list
}
auth_events[("m.room.create", "")] = create_events[e.room_id]
try:
auth.check(e, auth_events=auth_events)
except Exception as ex:
print "Failed:", e.event_id, e.type, e.state_key
print "Auth_events:", auth_events
print ex
print json.dumps(e.get_dict(), sort_keys=True, indent=4)
# raise
print "Success:", e.event_id, e.type, e.state_key
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'json',
nargs='?',
type=argparse.FileType('r'),
default=sys.stdin,
)
args = parser.parse_args()
js = json.load(args.json)
auth = Auth(Mock())
check_auth(
auth,
[FrozenEvent(d) for d in js["auth_chain"]],
[FrozenEvent(d) for d in js["pdus"]],
)

View file

@ -102,8 +102,6 @@ class Auth(object):
def check_host_in_room(self, room_id, host): def check_host_in_room(self, room_id, host):
curr_state = yield self.state.get_current_state(room_id) curr_state = yield self.state.get_current_state(room_id)
logger.debug("Got curr_state %s", curr_state)
for event in curr_state: for event in curr_state:
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
try: try:
@ -360,7 +358,7 @@ class Auth(object):
def add_auth_events(self, builder, context): def add_auth_events(self, builder, context):
yield run_on_reactor() yield run_on_reactor()
auth_ids = self.compute_auth_events(builder, context) auth_ids = self.compute_auth_events(builder, context.current_state)
auth_events_entries = yield self.store.add_event_hashes( auth_events_entries = yield self.store.add_event_hashes(
auth_ids auth_ids
@ -374,26 +372,26 @@ class Auth(object):
if v.event_id in auth_ids if v.event_id in auth_ids
} }
def compute_auth_events(self, event, context): def compute_auth_events(self, event, current_state):
if event.type == EventTypes.Create: if event.type == EventTypes.Create:
return [] return []
auth_ids = [] auth_ids = []
key = (EventTypes.PowerLevels, "", ) key = (EventTypes.PowerLevels, "", )
power_level_event = context.current_state.get(key) power_level_event = current_state.get(key)
if power_level_event: if power_level_event:
auth_ids.append(power_level_event.event_id) auth_ids.append(power_level_event.event_id)
key = (EventTypes.JoinRules, "", ) key = (EventTypes.JoinRules, "", )
join_rule_event = context.current_state.get(key) join_rule_event = current_state.get(key)
key = (EventTypes.Member, event.user_id, ) key = (EventTypes.Member, event.user_id, )
member_event = context.current_state.get(key) member_event = current_state.get(key)
key = (EventTypes.Create, "", ) key = (EventTypes.Create, "", )
create_event = context.current_state.get(key) create_event = current_state.get(key)
if create_event: if create_event:
auth_ids.append(create_event.event_id) auth_ids.append(create_event.event_id)

View file

@ -39,7 +39,7 @@ class Codes(object):
TOO_LARGE = "M_TOO_LARGE" TOO_LARGE = "M_TOO_LARGE"
class CodeMessageException(Exception): class CodeMessageException(RuntimeError):
"""An exception with integer code and message string attributes.""" """An exception with integer code and message string attributes."""
def __init__(self, code, msg): def __init__(self, code, msg):
@ -227,3 +227,9 @@ class FederationError(RuntimeError):
"affected": self.affected, "affected": self.affected,
"source": self.source if self.source else self.affected, "source": self.source if self.source else self.affected,
} }
class HttpResponseException(CodeMessageException):
def __init__(self, code, msg, response):
self.response = response
super(HttpResponseException, self).__init__(code, msg)

View file

@ -77,7 +77,7 @@ class EventBase(object):
return self.content["membership"] return self.content["membership"]
def is_state(self): def is_state(self):
return hasattr(self, "state_key") return hasattr(self, "state_key") and self.state_key is not None
def get_dict(self): def get_dict(self):
d = dict(self._event_dict) d = dict(self._event_dict)

View file

@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError
import logging
logger = logging.getLogger(__name__)
class FederationBase(object):
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
that PDU.
If a PDU fails its content hash check then it is redacted.
The given list of PDUs are not modified, instead the function returns
a new list.
Args:
pdu (list)
outlier (bool)
Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
"""
signed_pdus = []
for pdu in pdus:
try:
new_pdu = yield self._check_sigs_and_hash(pdu)
signed_pdus.append(new_pdu)
except SynapseError:
# FIXME: We should handle signature failures more gracefully.
# Check local db.
new_pdu = yield self.store.get_event(
pdu.event_id,
allow_rejected=True
)
if new_pdu:
signed_pdus.append(new_pdu)
continue
# Check pdu.origin
if pdu.origin != origin:
new_pdu = yield self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
outlier=outlier,
)
if new_pdu:
signed_pdus.append(new_pdu)
continue
logger.warn("Failed to find copy of %s with valid signature")
defer.returnValue(signed_pdus)
@defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu):
"""Throws a SynapseError if the PDU does not have the correct
signatures.
Returns:
FrozenEvent: Either the given event or it redacted if it failed the
content hash check.
"""
# Check signatures are correct.
redacted_event = prune_event(pdu)
redacted_pdu_json = redacted_event.get_pdu_json()
try:
yield self.keyring.verify_json_for_server(
pdu.origin, redacted_pdu_json
)
except SynapseError:
logger.warn(
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
)
raise
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s, %s",
pdu.event_id, encode_canonical_json(pdu.get_dict())
)
defer.returnValue(redacted_event)
defer.returnValue(pdu)

View file

@ -16,17 +16,12 @@
from twisted.internet import defer from twisted.internet import defer
from .federation_base import FederationBase
from .units import Edu from .units import Edu
from synapse.api.errors import CodeMessageException
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError
import logging import logging
@ -34,7 +29,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FederationClient(object): class FederationClient(FederationBase):
@log_function @log_function
def send_pdu(self, pdu, destinations): def send_pdu(self, pdu, destinations):
"""Informs the replication layer about a new PDU generated within the """Informs the replication layer about a new PDU generated within the
@ -186,7 +181,8 @@ class FederationClient(object):
pdu = yield self._check_sigs_and_hash(pdu) pdu = yield self._check_sigs_and_hash(pdu)
break break
except CodeMessageException:
raise
except Exception as e: except Exception as e:
logger.info( logger.info(
"Failed to get PDU %s from %s because %s", "Failed to get PDU %s from %s because %s",
@ -224,17 +220,17 @@ class FederationClient(object):
for p in result.get("auth_chain", []) for p in result.get("auth_chain", [])
] ]
for i, pdu in enumerate(pdus): signed_pdus = yield self._check_sigs_and_hash_and_fetch(
pdus[i] = yield self._check_sigs_and_hash(pdu) destination, pdus, outlier=True
)
# FIXME: We should handle signature failures more gracefully. signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
for i, pdu in enumerate(auth_chain): signed_auth.sort(key=lambda e: e.depth)
auth_chain[i] = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully. defer.returnValue((signed_pdus, signed_auth))
defer.returnValue((pdus, auth_chain))
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -248,65 +244,88 @@ class FederationClient(object):
for p in res["auth_chain"] for p in res["auth_chain"]
] ]
for i, pdu in enumerate(auth_chain): signed_auth = yield self._check_sigs_and_hash_and_fetch(
auth_chain[i] = yield self._check_sigs_and_hash(pdu) destination, auth_chain, outlier=True
# FIXME: We should handle signature failures more gracefully.
auth_chain.sort(key=lambda e: e.depth)
defer.returnValue(auth_chain)
@defer.inlineCallbacks
def make_join(self, destination, room_id, user_id):
ret = yield self.transport_layer.make_join(
destination, room_id, user_id
) )
pdu_dict = ret["event"] signed_auth.sort(key=lambda e: e.depth)
logger.debug("Got response to make_join: %s", pdu_dict) defer.returnValue(signed_auth)
defer.returnValue(self.event_from_pdu_json(pdu_dict))
@defer.inlineCallbacks @defer.inlineCallbacks
def send_join(self, destination, pdu): def make_join(self, destinations, room_id, user_id):
time_now = self._clock.time_msec() for destination in destinations:
_, content = yield self.transport_layer.send_join( try:
destination=destination, ret = yield self.transport_layer.make_join(
room_id=pdu.room_id, destination, room_id, user_id
event_id=pdu.event_id, )
content=pdu.get_pdu_json(time_now),
)
logger.debug("Got content: %s", content) pdu_dict = ret["event"]
state = [ logger.debug("Got response to make_join: %s", pdu_dict)
self.event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]
auth_chain = [ defer.returnValue(
self.event_from_pdu_json(p, outlier=True) (destination, self.event_from_pdu_json(pdu_dict))
for p in content.get("auth_chain", []) )
] break
except CodeMessageException:
raise
except Exception as e:
logger.warn(
"Failed to make_join via %s: %s",
destination, e.message
)
for i, pdu in enumerate(state): raise RuntimeError("Failed to send to any server.")
state[i] = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully. @defer.inlineCallbacks
def send_join(self, destinations, pdu):
for destination in destinations:
try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
for i, pdu in enumerate(auth_chain): logger.debug("Got content: %s", content)
auth_chain[i] = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully. state = [
self.event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]
auth_chain.sort(key=lambda e: e.depth) auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", [])
]
defer.returnValue({ signed_state = yield self._check_sigs_and_hash_and_fetch(
"state": state, destination, state, outlier=True
"auth_chain": auth_chain, )
})
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
auth_chain.sort(key=lambda e: e.depth)
defer.returnValue({
"state": signed_state,
"auth_chain": signed_auth,
"origin": destination,
})
except CodeMessageException:
raise
except Exception as e:
logger.warn(
"Failed to send_join via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks @defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu): def send_invite(self, destination, room_id, event_id, pdu):
@ -353,12 +372,18 @@ class FederationClient(object):
) )
auth_chain = [ auth_chain = [
(yield self._check_sigs_and_hash(self.event_from_pdu_json(e))) self.event_from_pdu_json(e)
for e in content["auth_chain"] for e in content["auth_chain"]
] ]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
signed_auth.sort(key=lambda e: e.depth)
ret = { ret = {
"auth_chain": auth_chain, "auth_chain": signed_auth,
"rejects": content.get("rejects", []), "rejects": content.get("rejects", []),
"missing": content.get("missing", []), "missing": content.get("missing", []),
} }
@ -373,37 +398,3 @@ class FederationClient(object):
event.internal_metadata.outlier = outlier event.internal_metadata.outlier = outlier
return event return event
@defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu):
"""Throws a SynapseError if the PDU does not have the correct
signatures.
Returns:
FrozenEvent: Either the given event or it redacted if it failed the
content hash check.
"""
# Check signatures are correct.
redacted_event = prune_event(pdu)
redacted_pdu_json = redacted_event.get_pdu_json()
try:
yield self.keyring.verify_json_for_server(
pdu.origin, redacted_pdu_json
)
except SynapseError:
logger.warn(
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
)
raise
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s, %s",
pdu.event_id, encode_canonical_json(pdu.get_dict())
)
defer.returnValue(redacted_event)
defer.returnValue(pdu)

View file

@ -16,16 +16,12 @@
from twisted.internet import defer from twisted.internet import defer
from .federation_base import FederationBase
from .units import Transaction, Edu from .units import Transaction, Edu
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import FederationError, SynapseError from synapse.api.errors import FederationError, SynapseError
@ -35,7 +31,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FederationServer(object): class FederationServer(FederationBase):
def set_handler(self, handler): def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate """Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are receipt of new PDUs from other home servers. The required methods are
@ -251,17 +247,20 @@ class FederationServer(object):
Deferred: Results in `dict` with the same format as `content` Deferred: Results in `dict` with the same format as `content`
""" """
auth_chain = [ auth_chain = [
(yield self._check_sigs_and_hash(self.event_from_pdu_json(e))) self.event_from_pdu_json(e)
for e in content["auth_chain"] for e in content["auth_chain"]
] ]
missing = [ signed_auth = yield self._check_sigs_and_hash_and_fetch(
(yield self._check_sigs_and_hash(self.event_from_pdu_json(e))) origin, auth_chain, outlier=True
for e in content.get("missing", []) )
]
ret = yield self.handler.on_query_auth( ret = yield self.handler.on_query_auth(
origin, event_id, auth_chain, content.get("rejects", []), missing origin,
event_id,
signed_auth,
content.get("rejects", []),
content.get("missing", []),
) )
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
@ -426,37 +425,3 @@ class FederationServer(object):
event.internal_metadata.outlier = outlier event.internal_metadata.outlier = outlier
return event return event
@defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu):
"""Throws a SynapseError if the PDU does not have the correct
signatures.
Returns:
FrozenEvent: Either the given event or it redacted if it failed the
content hash check.
"""
# Check signatures are correct.
redacted_event = prune_event(pdu)
redacted_pdu_json = redacted_event.get_pdu_json()
try:
yield self.keyring.verify_json_for_server(
pdu.origin, redacted_pdu_json
)
except SynapseError:
logger.warn(
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
)
raise
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s, %s",
pdu.event_id, encode_canonical_json(pdu.get_dict())
)
defer.returnValue(redacted_event)
defer.returnValue(pdu)

View file

@ -19,6 +19,7 @@ from twisted.internet import defer
from .persistence import TransactionActions from .persistence import TransactionActions
from .units import Transaction from .units import Transaction
from synapse.api.errors import HttpResponseException
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
@ -238,9 +239,14 @@ class TransactionQueue(object):
del p["age_ts"] del p["age_ts"]
return data return data
code, response = yield self.transport_layer.send_transaction( try:
transaction, json_data_cb response = yield self.transport_layer.send_transaction(
) transaction, json_data_cb
)
code = 200
except HttpResponseException as e:
code = e.code
response = e.response
logger.info("TX [%s] got %d response", destination, code) logger.info("TX [%s] got %d response", destination, code)
@ -274,8 +280,7 @@ class TransactionQueue(object):
pass pass
logger.debug("TX [%s] Yielded to callbacks", destination) logger.debug("TX [%s] Yielded to callbacks", destination)
except RuntimeError as e:
except Exception as e:
# We capture this here as there as nothing actually listens # We capture this here as there as nothing actually listens
# for this finishing functions deferred. # for this finishing functions deferred.
logger.warn( logger.warn(
@ -283,6 +288,14 @@ class TransactionQueue(object):
destination, destination,
e, e,
) )
except Exception as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.exception(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
self.set_retrying(destination, retry_interval) self.set_retrying(destination, retry_interval)

View file

@ -19,7 +19,6 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
import logging import logging
import json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -129,7 +128,7 @@ class TransportLayerClient(object):
# generated by the json_data_callback. # generated by the json_data_callback.
json_data = transaction.get_dict() json_data = transaction.get_dict()
code, response = yield self.client.put_json( response = yield self.client.put_json(
transaction.destination, transaction.destination,
path=PREFIX + "/send/%s/" % transaction.transaction_id, path=PREFIX + "/send/%s/" % transaction.transaction_id,
data=json_data, data=json_data,
@ -137,95 +136,86 @@ class TransportLayerClient(object):
) )
logger.debug( logger.debug(
"send_data dest=%s, txid=%s, got response: %d", "send_data dest=%s, txid=%s, got response: 200",
transaction.destination, transaction.transaction_id, code transaction.destination, transaction.transaction_id,
) )
defer.returnValue((code, response)) defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def make_query(self, destination, query_type, args, retry_on_dns_fail): def make_query(self, destination, query_type, args, retry_on_dns_fail):
path = PREFIX + "/query/%s" % query_type path = PREFIX + "/query/%s" % query_type
response = yield self.client.get_json( content = yield self.client.get_json(
destination=destination, destination=destination,
path=path, path=path,
args=args, args=args,
retry_on_dns_fail=retry_on_dns_fail, retry_on_dns_fail=retry_on_dns_fail,
) )
defer.returnValue(response) defer.returnValue(content)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def make_join(self, destination, room_id, user_id, retry_on_dns_fail=True): def make_join(self, destination, room_id, user_id, retry_on_dns_fail=True):
path = PREFIX + "/make_join/%s/%s" % (room_id, user_id) path = PREFIX + "/make_join/%s/%s" % (room_id, user_id)
response = yield self.client.get_json( content = yield self.client.get_json(
destination=destination, destination=destination,
path=path, path=path,
retry_on_dns_fail=retry_on_dns_fail, retry_on_dns_fail=retry_on_dns_fail,
) )
defer.returnValue(response) defer.returnValue(content)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def send_join(self, destination, room_id, event_id, content): def send_join(self, destination, room_id, event_id, content):
path = PREFIX + "/send_join/%s/%s" % (room_id, event_id) path = PREFIX + "/send_join/%s/%s" % (room_id, event_id)
code, content = yield self.client.put_json( response = yield self.client.put_json(
destination=destination, destination=destination,
path=path, path=path,
data=content, data=content,
) )
if not 200 <= code < 300:
raise RuntimeError("Got %d from send_join", code)
defer.returnValue(json.loads(content))
@defer.inlineCallbacks
@log_function
def send_invite(self, destination, room_id, event_id, content):
path = PREFIX + "/invite/%s/%s" % (room_id, event_id)
code, content = yield self.client.put_json(
destination=destination,
path=path,
data=content,
)
if not 200 <= code < 300:
raise RuntimeError("Got %d from send_invite", code)
defer.returnValue(json.loads(content))
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
path = PREFIX + "/event_auth/%s/%s" % (room_id, event_id)
response = yield self.client.get_json(
destination=destination,
path=path,
)
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def send_query_auth(self, destination, room_id, event_id, content): def send_invite(self, destination, room_id, event_id, content):
path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id) path = PREFIX + "/invite/%s/%s" % (room_id, event_id)
code, content = yield self.client.post_json( response = yield self.client.put_json(
destination=destination, destination=destination,
path=path, path=path,
data=content, data=content,
) )
if not 200 <= code < 300: defer.returnValue(response)
raise RuntimeError("Got %d from send_invite", code)
defer.returnValue(json.loads(content)) @defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
path = PREFIX + "/event_auth/%s/%s" % (room_id, event_id)
content = yield self.client.get_json(
destination=destination,
path=path,
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def send_query_auth(self, destination, room_id, event_id, content):
path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id)
content = yield self.client.post_json(
destination=destination,
path=path,
data=content,
)
defer.returnValue(content)

View file

@ -120,7 +120,16 @@ class DirectoryHandler(BaseHandler):
) )
extra_servers = yield self.store.get_joined_hosts_for_room(room_id) extra_servers = yield self.store.get_joined_hosts_for_room(room_id)
servers = list(set(extra_servers) | set(servers)) servers = set(extra_servers) | set(servers)
# If this server is in the list of servers, return it first.
if self.server_name in servers:
servers = (
[self.server_name]
+ [s for s in servers if s != self.server_name]
)
else:
servers = list(servers)
defer.returnValue({ defer.returnValue({
"room_id": room_id, "room_id": room_id,

View file

@ -30,6 +30,7 @@ from synapse.types import UserID
from twisted.internet import defer from twisted.internet import defer
import itertools
import logging import logging
@ -123,8 +124,21 @@ class FederationHandler(BaseHandler):
logger.debug("Got event for room we're not in.") logger.debug("Got event for room we're not in.")
current_state = state current_state = state
event_ids = set()
if state:
event_ids |= {e.event_id for e in state}
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
seen_ids = (yield self.store.have_events(event_ids)).keys()
if state and auth_chain is not None: if state and auth_chain is not None:
for e in state: # If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.)
for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: try:
auth_ids = [e_id for e_id, _ in e.auth_events] auth_ids = [e_id for e_id, _ in e.auth_events]
@ -132,7 +146,10 @@ class FederationHandler(BaseHandler):
(e.type, e.state_key): e for e in auth_chain (e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids if e.event_id in auth_ids
} }
yield self._handle_new_event(origin, e, auth_events=auth) yield self._handle_new_event(
origin, e, auth_events=auth
)
seen_ids.add(e.event_id)
except: except:
logger.exception( logger.exception(
"Failed to handle state event %s", "Failed to handle state event %s",
@ -256,7 +273,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def do_invite_join(self, target_host, room_id, joinee, content, snapshot): def do_invite_join(self, target_hosts, room_id, joinee, content, snapshot):
""" Attempts to join the `joinee` to the room `room_id` via the """ Attempts to join the `joinee` to the room `room_id` via the
server `target_host`. server `target_host`.
@ -270,8 +287,8 @@ class FederationHandler(BaseHandler):
""" """
logger.debug("Joining %s to %s", joinee, room_id) logger.debug("Joining %s to %s", joinee, room_id)
pdu = yield self.replication_layer.make_join( origin, pdu = yield self.replication_layer.make_join(
target_host, target_hosts,
room_id, room_id,
joinee joinee
) )
@ -313,11 +330,17 @@ class FederationHandler(BaseHandler):
new_event = builder.build() new_event = builder.build()
# Try the host we successfully got a response to /make_join/
# request first.
target_hosts.remove(origin)
target_hosts.insert(0, origin)
ret = yield self.replication_layer.send_join( ret = yield self.replication_layer.send_join(
target_host, target_hosts,
new_event new_event
) )
origin = ret["origin"]
state = ret["state"] state = ret["state"]
auth_chain = ret["auth_chain"] auth_chain = ret["auth_chain"]
auth_chain.sort(key=lambda e: e.depth) auth_chain.sort(key=lambda e: e.depth)
@ -354,7 +377,7 @@ class FederationHandler(BaseHandler):
if e.event_id in auth_ids if e.event_id in auth_ids
} }
yield self._handle_new_event( yield self._handle_new_event(
target_host, e, auth_events=auth origin, e, auth_events=auth
) )
except: except:
logger.exception( logger.exception(
@ -374,7 +397,7 @@ class FederationHandler(BaseHandler):
if e.event_id in auth_ids if e.event_id in auth_ids
} }
yield self._handle_new_event( yield self._handle_new_event(
target_host, e, auth_events=auth origin, e, auth_events=auth
) )
except: except:
logger.exception( logger.exception(
@ -389,7 +412,7 @@ class FederationHandler(BaseHandler):
} }
yield self._handle_new_event( yield self._handle_new_event(
target_host, origin,
new_event, new_event,
state=state, state=state,
current_state=state, current_state=state,
@ -498,6 +521,8 @@ class FederationHandler(BaseHandler):
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id
) )
destinations.remove(origin)
logger.debug( logger.debug(
"on_send_join_request: Sending event: %s, signatures: %s", "on_send_join_request: Sending event: %s, signatures: %s",
event.event_id, event.event_id,
@ -618,6 +643,7 @@ class FederationHandler(BaseHandler):
event = yield self.store.get_event( event = yield self.store.get_event(
event_id, event_id,
allow_none=True, allow_none=True,
allow_rejected=True,
) )
if event: if event:
@ -701,6 +727,8 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR context.rejected = RejectedReason.AUTH_ERROR
# FIXME: Don't store as rejected with AUTH_ERROR if we haven't
# seen all the auth events.
yield self.store.persist_event( yield self.store.persist_event(
event, event,
context=context, context=context,
@ -750,7 +778,7 @@ class FederationHandler(BaseHandler):
) )
) )
logger.debug("on_query_auth reutrning: %s", ret) logger.debug("on_query_auth returning: %s", ret)
defer.returnValue(ret) defer.returnValue(ret)
@ -770,41 +798,45 @@ class FederationHandler(BaseHandler):
if missing_auth: if missing_auth:
logger.debug("Missing auth: %s", missing_auth) logger.debug("Missing auth: %s", missing_auth)
# If we don't have all the auth events, we need to get them. # If we don't have all the auth events, we need to get them.
remote_auth_chain = yield self.replication_layer.get_event_auth( try:
origin, event.room_id, event.event_id remote_auth_chain = yield self.replication_layer.get_event_auth(
) origin, event.room_id, event.event_id
)
seen_remotes = yield self.store.have_events( seen_remotes = yield self.store.have_events(
[e.event_id for e in remote_auth_chain] [e.event_id for e in remote_auth_chain]
) )
for e in remote_auth_chain: for e in remote_auth_chain:
if e.event_id in seen_remotes.keys(): if e.event_id in seen_remotes.keys():
continue continue
if e.event_id == event.event_id: if e.event_id == event.event_id:
continue continue
try: try:
auth_ids = [e_id for e_id, _ in e.auth_events] auth_ids = [e_id for e_id, _ in e.auth_events]
auth = { auth = {
(e.type, e.state_key): e for e in remote_auth_chain (e.type, e.state_key): e for e in remote_auth_chain
if e.event_id in auth_ids if e.event_id in auth_ids
} }
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
logger.debug( logger.debug(
"do_auth %s missing_auth: %s", "do_auth %s missing_auth: %s",
event.event_id, e.event_id event.event_id, e.event_id
) )
yield self._handle_new_event( yield self._handle_new_event(
origin, e, auth_events=auth origin, e, auth_events=auth
) )
if e.event_id in event_auth_events: if e.event_id in event_auth_events:
auth_events[(e.type, e.state_key)] = e auth_events[(e.type, e.state_key)] = e
except AuthError: except AuthError:
pass pass
except:
# FIXME:
logger.exception("Failed to get auth chain")
# FIXME: Assumes we have and stored all the state for all the # FIXME: Assumes we have and stored all the state for all the
# prev_events # prev_events
@ -816,50 +848,57 @@ class FederationHandler(BaseHandler):
logger.debug("Different auth: %s", different_auth) logger.debug("Different auth: %s", different_auth)
# 1. Get what we think is the auth chain. # 1. Get what we think is the auth chain.
auth_ids = self.auth.compute_auth_events(event, context) auth_ids = self.auth.compute_auth_events(
event, context.current_state
)
local_auth_chain = yield self.store.get_auth_chain(auth_ids) local_auth_chain = yield self.store.get_auth_chain(auth_ids)
# 2. Get remote difference. try:
result = yield self.replication_layer.query_auth( # 2. Get remote difference.
origin, result = yield self.replication_layer.query_auth(
event.room_id, origin,
event.event_id, event.room_id,
local_auth_chain, event.event_id,
) local_auth_chain,
)
seen_remotes = yield self.store.have_events( seen_remotes = yield self.store.have_events(
[e.event_id for e in result["auth_chain"]] [e.event_id for e in result["auth_chain"]]
) )
# 3. Process any remote auth chain events we haven't seen. # 3. Process any remote auth chain events we haven't seen.
for ev in result["auth_chain"]: for ev in result["auth_chain"]:
if ev.event_id in seen_remotes.keys(): if ev.event_id in seen_remotes.keys():
continue continue
if ev.event_id == event.event_id: if ev.event_id == event.event_id:
continue continue
try: try:
auth_ids = [e_id for e_id, _ in ev.auth_events] auth_ids = [e_id for e_id, _ in ev.auth_events]
auth = { auth = {
(e.type, e.state_key): e for e in result["auth_chain"] (e.type, e.state_key): e for e in result["auth_chain"]
if e.event_id in auth_ids if e.event_id in auth_ids
} }
ev.internal_metadata.outlier = True ev.internal_metadata.outlier = True
logger.debug( logger.debug(
"do_auth %s different_auth: %s", "do_auth %s different_auth: %s",
event.event_id, e.event_id event.event_id, e.event_id
) )
yield self._handle_new_event( yield self._handle_new_event(
origin, ev, auth_events=auth origin, ev, auth_events=auth
) )
if ev.event_id in event_auth_events: if ev.event_id in event_auth_events:
auth_events[(ev.type, ev.state_key)] = ev auth_events[(ev.type, ev.state_key)] = ev
except AuthError: except AuthError:
pass pass
except:
# FIXME:
logger.exception("Failed to query auth chain")
# 4. Look at rejects and their proofs. # 4. Look at rejects and their proofs.
# TODO. # TODO.
@ -983,7 +1022,7 @@ class FederationHandler(BaseHandler):
if reason is None: if reason is None:
# FIXME: ERRR?! # FIXME: ERRR?!
logger.warn("Could not find reason for %s", e.event_id) logger.warn("Could not find reason for %s", e.event_id)
raise RuntimeError("") raise RuntimeError("Could not find reason for %s" % e.event_id)
reason_map[e.event_id] = reason reason_map[e.event_id] = reason

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import RoomError from synapse.api.errors import RoomError, SynapseError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
@ -372,10 +372,17 @@ class MessageHandler(BaseHandler):
as_event=True, as_event=True,
) )
presence.append(member_presence) presence.append(member_presence)
except Exception: except SynapseError as e:
logger.exception( if e.code == 404:
"Failed to get member presence of %r", m.user_id # FIXME: We are doing this as a warn since this gets hit a
) # lot and spams the logs. Why is this happening?
logger.warn(
"Failed to get member presence of %r", m.user_id
)
else:
logger.exception(
"Failed to get member presence of %r", m.user_id
)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View file

@ -389,8 +389,6 @@ class RoomMemberHandler(BaseHandler):
if not hosts: if not hosts:
raise SynapseError(404, "No known servers") raise SynapseError(404, "No known servers")
host = hosts[0]
# If event doesn't include a display name, add one. # If event doesn't include a display name, add one.
yield self.distributor.fire( yield self.distributor.fire(
"collect_presencelike_data", joinee, content "collect_presencelike_data", joinee, content
@ -407,12 +405,12 @@ class RoomMemberHandler(BaseHandler):
}) })
event, context = yield self._create_new_client_event(builder) event, context = yield self._create_new_client_event(builder)
yield self._do_join(event, context, room_host=host, do_auth=True) yield self._do_join(event, context, room_hosts=hosts, do_auth=True)
defer.returnValue({"room_id": room_id}) defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_join(self, event, context, room_host=None, do_auth=True): def _do_join(self, event, context, room_hosts=None, do_auth=True):
joinee = UserID.from_string(event.state_key) joinee = UserID.from_string(event.state_key)
# room_id = RoomID.from_string(event.room_id, self.hs) # room_id = RoomID.from_string(event.room_id, self.hs)
room_id = event.room_id room_id = event.room_id
@ -441,7 +439,7 @@ class RoomMemberHandler(BaseHandler):
if is_host_in_room: if is_host_in_room:
should_do_dance = False should_do_dance = False
elif room_host: # TODO: Shouldn't this be remote_room_host? elif room_hosts: # TODO: Shouldn't this be remote_room_host?
should_do_dance = True should_do_dance = True
else: else:
# TODO(markjh): get prev_state from snapshot # TODO(markjh): get prev_state from snapshot
@ -453,7 +451,7 @@ class RoomMemberHandler(BaseHandler):
inviter = UserID.from_string(prev_state.user_id) inviter = UserID.from_string(prev_state.user_id)
should_do_dance = not self.hs.is_mine(inviter) should_do_dance = not self.hs.is_mine(inviter)
room_host = inviter.domain room_hosts = [inviter.domain]
else: else:
# return the same error as join_room_alias does # return the same error as join_room_alias does
raise SynapseError(404, "No known servers") raise SynapseError(404, "No known servers")
@ -461,7 +459,7 @@ class RoomMemberHandler(BaseHandler):
if should_do_dance: if should_do_dance:
handler = self.hs.get_handlers().federation_handler handler = self.hs.get_handlers().federation_handler
yield handler.do_invite_join( yield handler.do_invite_join(
room_host, room_hosts,
room_id, room_id,
event.user_id, event.user_id,
event.get_dict()["content"], # FIXME To get a non-frozen dict event.get_dict()["content"], # FIXME To get a non-frozen dict

View file

@ -27,7 +27,9 @@ from synapse.util.logcontext import PreserveLoggingContext
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
from synapse.api.errors import CodeMessageException, SynapseError, Codes from synapse.api.errors import (
SynapseError, Codes, HttpResponseException,
)
from syutil.crypto.jsonsign import sign_json from syutil.crypto.jsonsign import sign_json
@ -163,13 +165,12 @@ class MatrixFederationHttpClient(object):
) )
if 200 <= response.code < 300: if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
pass pass
else: else:
# :'( # :'(
# Update transactions table? # Update transactions table?
raise CodeMessageException( raise HttpResponseException(
response.code, response.phrase response.code, response.phrase, response
) )
defer.returnValue(response) defer.returnValue(response)
@ -238,11 +239,20 @@ class MatrixFederationHttpClient(object):
headers_dict={"Content-Type": ["application/json"]}, headers_dict={"Content-Type": ["application/json"]},
) )
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
c_type = response.headers.getRawHeaders("Content-Type")
if "application/json" not in c_type:
raise RuntimeError(
"Content-Type not application/json"
)
logger.debug("Getting resp body") logger.debug("Getting resp body")
body = yield readBody(response) body = yield readBody(response)
logger.debug("Got resp body") logger.debug("Got resp body")
defer.returnValue((response.code, body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def post_json(self, destination, path, data={}): def post_json(self, destination, path, data={}):
@ -275,11 +285,20 @@ class MatrixFederationHttpClient(object):
headers_dict={"Content-Type": ["application/json"]}, headers_dict={"Content-Type": ["application/json"]},
) )
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
c_type = response.headers.getRawHeaders("Content-Type")
if "application/json" not in c_type:
raise RuntimeError(
"Content-Type not application/json"
)
logger.debug("Getting resp body") logger.debug("Getting resp body")
body = yield readBody(response) body = yield readBody(response)
logger.debug("Got resp body") logger.debug("Got resp body")
defer.returnValue((response.code, body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True): def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
@ -321,7 +340,18 @@ class MatrixFederationHttpClient(object):
retry_on_dns_fail=retry_on_dns_fail retry_on_dns_fail=retry_on_dns_fail
) )
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
c_type = response.headers.getRawHeaders("Content-Type")
if "application/json" not in c_type:
raise RuntimeError(
"Content-Type not application/json"
)
logger.debug("Getting resp body")
body = yield readBody(response) body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))

View file

@ -37,14 +37,14 @@ class Pusher(object):
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def __init__(self, _hs, instance_handle, user_name, app_id, def __init__(self, _hs, profile_tag, user_name, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts, app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since): data, last_token, last_success, failing_since):
self.hs = _hs self.hs = _hs
self.evStreamHandler = self.hs.get_handlers().event_stream_handler self.evStreamHandler = self.hs.get_handlers().event_stream_handler
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.instance_handle = instance_handle self.profile_tag = profile_tag
self.user_name = user_name self.user_name = user_name
self.app_id = app_id self.app_id = app_id
self.app_display_name = app_display_name self.app_display_name = app_display_name
@ -147,9 +147,9 @@ class Pusher(object):
return False return False
return fnmatch.fnmatch(val.upper(), pat.upper()) return fnmatch.fnmatch(val.upper(), pat.upper())
elif condition['kind'] == 'device': elif condition['kind'] == 'device':
if 'instance_handle' not in condition: if 'profile_tag' not in condition:
return True return True
return condition['instance_handle'] == self.instance_handle return condition['profile_tag'] == self.profile_tag
elif condition['kind'] == 'contains_display_name': elif condition['kind'] == 'contains_display_name':
# This is special because display names can be different # This is special because display names can be different
# between rooms and so you can't really hard code it in a rule. # between rooms and so you can't really hard code it in a rule.
@ -400,8 +400,8 @@ def _tweaks_for_actions(actions):
for a in actions: for a in actions:
if not isinstance(a, dict): if not isinstance(a, dict):
continue continue
if 'set_sound' in a: if 'set_tweak' in a and 'value' in a:
tweaks['sound'] = a['set_sound'] tweaks[a['set_tweak']] = a['value']
return tweaks return tweaks

View file

@ -38,7 +38,8 @@ def make_base_rules(user_name):
'actions': [ 'actions': [
'notify', 'notify',
{ {
'set_sound': 'default' 'set_tweak': 'sound',
'value': 'default'
} }
] ]
} }

View file

@ -24,12 +24,12 @@ logger = logging.getLogger(__name__)
class HttpPusher(Pusher): class HttpPusher(Pusher):
def __init__(self, _hs, instance_handle, user_name, app_id, def __init__(self, _hs, profile_tag, user_name, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts, app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since): data, last_token, last_success, failing_since):
super(HttpPusher, self).__init__( super(HttpPusher, self).__init__(
_hs, _hs,
instance_handle, profile_tag,
user_name, user_name,
app_id, app_id,
app_display_name, app_display_name,

View file

@ -55,7 +55,7 @@ class PusherPool:
self._start_pushers(pushers) self._start_pushers(pushers)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_pusher(self, user_name, instance_handle, kind, app_id, def add_pusher(self, user_name, profile_tag, kind, app_id,
app_display_name, device_display_name, pushkey, lang, data): app_display_name, device_display_name, pushkey, lang, data):
# we try to create the pusher just to validate the config: it # we try to create the pusher just to validate the config: it
# will then get pulled out of the database, # will then get pulled out of the database,
@ -64,7 +64,7 @@ class PusherPool:
self._create_pusher({ self._create_pusher({
"user_name": user_name, "user_name": user_name,
"kind": kind, "kind": kind,
"instance_handle": instance_handle, "profile_tag": profile_tag,
"app_id": app_id, "app_id": app_id,
"app_display_name": app_display_name, "app_display_name": app_display_name,
"device_display_name": device_display_name, "device_display_name": device_display_name,
@ -77,18 +77,18 @@ class PusherPool:
"failing_since": None "failing_since": None
}) })
yield self._add_pusher_to_store( yield self._add_pusher_to_store(
user_name, instance_handle, kind, app_id, user_name, profile_tag, kind, app_id,
app_display_name, device_display_name, app_display_name, device_display_name,
pushkey, lang, data pushkey, lang, data
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _add_pusher_to_store(self, user_name, instance_handle, kind, app_id, def _add_pusher_to_store(self, user_name, profile_tag, kind, app_id,
app_display_name, device_display_name, app_display_name, device_display_name,
pushkey, lang, data): pushkey, lang, data):
yield self.store.add_pusher( yield self.store.add_pusher(
user_name=user_name, user_name=user_name,
instance_handle=instance_handle, profile_tag=profile_tag,
kind=kind, kind=kind,
app_id=app_id, app_id=app_id,
app_display_name=app_display_name, app_display_name=app_display_name,
@ -104,7 +104,7 @@ class PusherPool:
if pusherdict['kind'] == 'http': if pusherdict['kind'] == 'http':
return HttpPusher( return HttpPusher(
self.hs, self.hs,
instance_handle=pusherdict['instance_handle'], profile_tag=pusherdict['profile_tag'],
user_name=pusherdict['user_name'], user_name=pusherdict['user_name'],
app_id=pusherdict['app_id'], app_id=pusherdict['app_id'],
app_display_name=pusherdict['app_display_name'], app_display_name=pusherdict['app_display_name'],

View file

@ -73,7 +73,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
'rule_id': rule_id 'rule_id': rule_id
} }
if device: if device:
spec['device'] = device spec['profile_tag'] = device
return spec return spec
def rule_tuple_from_request_object(self, rule_template, rule_id, req_obj, device=None): def rule_tuple_from_request_object(self, rule_template, rule_id, req_obj, device=None):
@ -112,7 +112,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
if device: if device:
conditions.append({ conditions.append({
'kind': 'device', 'kind': 'device',
'instance_handle': device 'profile_tag': device
}) })
if 'actions' not in req_obj: if 'actions' not in req_obj:
@ -188,15 +188,15 @@ class PushRuleRestServlet(ClientV1RestServlet):
user, _ = yield self.auth.get_user_by_req(request) user, _ = yield self.auth.get_user_by_req(request)
if 'device' in spec: if 'profile_tag' in spec:
rules = yield self.hs.get_datastore().get_push_rules_for_user_name( rules = yield self.hs.get_datastore().get_push_rules_for_user_name(
user.to_string() user.to_string()
) )
for r in rules: for r in rules:
conditions = json.loads(r['conditions']) conditions = json.loads(r['conditions'])
ih = _instance_handle_from_conditions(conditions) pt = _profile_tag_from_conditions(conditions)
if ih == spec['device'] and r['priority_class'] == priority_class: if pt == spec['profile_tag'] and r['priority_class'] == priority_class:
yield self.hs.get_datastore().delete_push_rule( yield self.hs.get_datastore().delete_push_rule(
user.to_string(), spec['rule_id'] user.to_string(), spec['rule_id']
) )
@ -239,19 +239,19 @@ class PushRuleRestServlet(ClientV1RestServlet):
if r['priority_class'] > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']: if r['priority_class'] > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']:
# per-device rule # per-device rule
instance_handle = _instance_handle_from_conditions(r["conditions"]) profile_tag = _profile_tag_from_conditions(r["conditions"])
r = _strip_device_condition(r) r = _strip_device_condition(r)
if not instance_handle: if not profile_tag:
continue continue
if instance_handle not in rules['device']: if profile_tag not in rules['device']:
rules['device'][instance_handle] = {} rules['device'][profile_tag] = {}
rules['device'][instance_handle] = ( rules['device'][profile_tag] = (
_add_empty_priority_class_arrays( _add_empty_priority_class_arrays(
rules['device'][instance_handle] rules['device'][profile_tag]
) )
) )
rulearray = rules['device'][instance_handle][template_name] rulearray = rules['device'][profile_tag][template_name]
else: else:
rulearray = rules['global'][template_name] rulearray = rules['global'][template_name]
@ -282,13 +282,13 @@ class PushRuleRestServlet(ClientV1RestServlet):
if path[0] == '': if path[0] == '':
defer.returnValue((200, rules['device'])) defer.returnValue((200, rules['device']))
instance_handle = path[0] profile_tag = path[0]
path = path[1:] path = path[1:]
if instance_handle not in rules['device']: if profile_tag not in rules['device']:
ret = {} ret = {}
ret = _add_empty_priority_class_arrays(ret) ret = _add_empty_priority_class_arrays(ret)
defer.returnValue((200, ret)) defer.returnValue((200, ret))
ruleset = rules['device'][instance_handle] ruleset = rules['device'][profile_tag]
result = _filter_ruleset_with_path(ruleset, path) result = _filter_ruleset_with_path(ruleset, path)
defer.returnValue((200, result)) defer.returnValue((200, result))
else: else:
@ -304,14 +304,14 @@ def _add_empty_priority_class_arrays(d):
return d return d
def _instance_handle_from_conditions(conditions): def _profile_tag_from_conditions(conditions):
""" """
Given a list of conditions, return the instance handle of the Given a list of conditions, return the profile tag of the
device rule if there is one device rule if there is one
""" """
for c in conditions: for c in conditions:
if c['kind'] == 'device': if c['kind'] == 'device':
return c['instance_handle'] return c['profile_tag']
return None return None

View file

@ -41,7 +41,7 @@ class PusherRestServlet(ClientV1RestServlet):
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
reqd = ['instance_handle', 'kind', 'app_id', 'app_display_name', reqd = ['profile_tag', 'kind', 'app_id', 'app_display_name',
'device_display_name', 'pushkey', 'lang', 'data'] 'device_display_name', 'pushkey', 'lang', 'data']
missing = [] missing = []
for i in reqd: for i in reqd:
@ -54,7 +54,7 @@ class PusherRestServlet(ClientV1RestServlet):
try: try:
yield pusher_pool.add_pusher( yield pusher_pool.add_pusher(
user_name=user.to_string(), user_name=user.to_string(),
instance_handle=content['instance_handle'], profile_tag=content['profile_tag'],
kind=content['kind'], kind=content['kind'],
app_id=content['app_id'], app_id=content['app_id'],
app_display_name=content['app_display_name'], app_display_name=content['app_display_name'],

View file

@ -37,7 +37,10 @@ def _get_state_key_from_event(event):
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
AuthEventTypes = (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,) AuthEventTypes = (
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
EventTypes.JoinRules,
)
class StateHandler(object): class StateHandler(object):
@ -100,7 +103,9 @@ class StateHandler(object):
context.state_group = None context.state_group = None
if hasattr(event, "auth_events") and event.auth_events: if hasattr(event, "auth_events") and event.auth_events:
auth_ids = zip(*event.auth_events)[0] auth_ids = self.hs.get_auth().compute_auth_events(
event, context.current_state
)
context.auth_events = { context.auth_events = {
k: v k: v
for k, v in context.current_state.items() for k, v in context.current_state.items()
@ -146,7 +151,9 @@ class StateHandler(object):
event.unsigned["replaces_state"] = replaces.event_id event.unsigned["replaces_state"] = replaces.event_id
if hasattr(event, "auth_events") and event.auth_events: if hasattr(event, "auth_events") and event.auth_events:
auth_ids = zip(*event.auth_events)[0] auth_ids = self.hs.get_auth().compute_auth_events(
event, context.current_state
)
context.auth_events = { context.auth_events = {
k: v k: v
for k, v in context.current_state.items() for k, v in context.current_state.items()
@ -258,6 +265,15 @@ class StateHandler(object):
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules:
resolved_state[key] = self._resolve_auth_events(
events,
auth_events
)
auth_events.update(resolved_state)
for key, events in conflicted_state.items(): for key, events in conflicted_state.items():
if key[0] == EventTypes.Member: if key[0] == EventTypes.Member:
resolved_state[key] = self._resolve_auth_events( resolved_state[key] = self._resolve_auth_events(

View file

@ -131,21 +131,144 @@ class DataStore(RoomMemberStore, RoomStore,
pass pass
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, event_id, allow_none=False): def get_event(self, event_id, check_redacted=True,
events = yield self._get_events([event_id]) get_prev_content=False, allow_rejected=False,
allow_none=False):
"""Get an event from the database by event_id.
if not events: Args:
if allow_none: event_id (str): The event_id of the event to fetch
defer.returnValue(None) check_redacted (bool): If True, check if event has been redacted
else: and redact it.
raise RuntimeError("Could not find event %s" % (event_id,)) get_prev_content (bool): If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
allow_none (bool): If True, return None if no event found, if
False throw an exception.
defer.returnValue(events[0]) Returns:
Deferred : A FrozenEvent.
"""
event = yield self.runInteraction(
"get_event", self._get_event_txn,
event_id,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
if not event and not allow_none:
raise RuntimeError("Could not find event %s" % (event_id,))
defer.returnValue(event)
@log_function @log_function
def _persist_event_txn(self, txn, event, context, backfilled, def _persist_event_txn(self, txn, event, context, backfilled,
stream_ordering=None, is_new_state=True, stream_ordering=None, is_new_state=True,
current_state=None): current_state=None):
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
txn.execute(
"DELETE FROM current_state_events WHERE room_id = ?",
(event.room_id,)
)
for s in current_state:
self._simple_insert_txn(
txn,
"current_state_events",
{
"event_id": s.event_id,
"room_id": s.room_id,
"type": s.type,
"state_key": s.state_key,
},
or_replace=True,
)
if event.is_state() and is_new_state:
if not backfilled and not context.rejected:
self._simple_insert_txn(
txn,
table="state_forward_extremities",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
or_replace=True,
)
for prev_state_id, _ in event.prev_state:
self._simple_delete_txn(
txn,
table="state_forward_extremities",
keyvalues={
"event_id": prev_state_id,
}
)
outlier = event.internal_metadata.is_outlier()
if not outlier:
self._store_state_groups_txn(txn, event, context)
self._update_min_depth_for_room_txn(
txn,
event.room_id,
event.depth
)
self._handle_prev_events(
txn,
outlier=outlier,
event_id=event.event_id,
prev_events=event.prev_events,
room_id=event.room_id,
)
have_persisted = self._simple_select_one_onecol_txn(
txn,
table="event_json",
keyvalues={"event_id": event.event_id},
retcol="event_id",
allow_none=True,
)
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
)
# If we have already persisted this event, we don't need to do any
# more processing.
# The processing above must be done on every call to persist event,
# since they might not have happened on previous calls. For example,
# if we are persisting an event that we had persisted as an outlier,
# but is no longer one.
if have_persisted:
if not outlier:
sql = (
"UPDATE event_json SET internal_metadata = ?"
" WHERE event_id = ?"
)
txn.execute(
sql,
(metadata_json.decode("UTF-8"), event.event_id,)
)
sql = (
"UPDATE events SET outlier = 0"
" WHERE event_id = ?"
)
txn.execute(
sql,
(event.event_id,)
)
return
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
self._store_room_member_txn(txn, event) self._store_room_member_txn(txn, event)
elif event.type == EventTypes.Feedback: elif event.type == EventTypes.Feedback:
@ -157,8 +280,6 @@ class DataStore(RoomMemberStore, RoomStore,
elif event.type == EventTypes.Redaction: elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event) self._store_redaction(txn, event)
outlier = event.internal_metadata.is_outlier()
event_dict = { event_dict = {
k: v k: v
for k, v in event.get_dict().items() for k, v in event.get_dict().items()
@ -168,10 +289,6 @@ class DataStore(RoomMemberStore, RoomStore,
] ]
} }
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
)
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="event_json", table="event_json",
@ -227,41 +344,10 @@ class DataStore(RoomMemberStore, RoomStore,
) )
raise _RollbackButIsFineException("_persist_event") raise _RollbackButIsFineException("_persist_event")
self._handle_prev_events(
txn,
outlier=outlier,
event_id=event.event_id,
prev_events=event.prev_events,
room_id=event.room_id,
)
if not outlier:
self._store_state_groups_txn(txn, event, context)
if context.rejected: if context.rejected:
self._store_rejections_txn(txn, event.event_id, context.rejected) self._store_rejections_txn(txn, event.event_id, context.rejected)
if current_state: if event.is_state():
txn.execute(
"DELETE FROM current_state_events WHERE room_id = ?",
(event.room_id,)
)
for s in current_state:
self._simple_insert_txn(
txn,
"current_state_events",
{
"event_id": s.event_id,
"room_id": s.room_id,
"type": s.type,
"state_key": s.state_key,
},
or_replace=True,
)
is_state = hasattr(event, "state_key") and event.state_key is not None
if is_state:
vals = { vals = {
"event_id": event.event_id, "event_id": event.event_id,
"room_id": event.room_id, "room_id": event.room_id,
@ -269,6 +355,7 @@ class DataStore(RoomMemberStore, RoomStore,
"state_key": event.state_key, "state_key": event.state_key,
} }
# TODO: How does this work with backfilling?
if hasattr(event, "replaces_state"): if hasattr(event, "replaces_state"):
vals["prev_state"] = event.replaces_state vals["prev_state"] = event.replaces_state
@ -305,28 +392,6 @@ class DataStore(RoomMemberStore, RoomStore,
or_ignore=True, or_ignore=True,
) )
if not backfilled and not context.rejected:
self._simple_insert_txn(
txn,
table="state_forward_extremities",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
or_replace=True,
)
for prev_state_id, _ in event.prev_state:
self._simple_delete_txn(
txn,
table="state_forward_extremities",
keyvalues={
"event_id": prev_state_id,
}
)
for hash_alg, hash_base64 in event.hashes.items(): for hash_alg, hash_base64 in event.hashes.items():
hash_bytes = decode_base64(hash_base64) hash_bytes = decode_base64(hash_base64)
self._store_event_content_hash_txn( self._store_event_content_hash_txn(
@ -357,13 +422,6 @@ class DataStore(RoomMemberStore, RoomStore,
txn, event.event_id, ref_alg, ref_hash_bytes txn, event.event_id, ref_alg, ref_hash_bytes
) )
if not outlier:
self._update_min_depth_for_room_txn(
txn,
event.room_id,
event.depth
)
def _store_redaction(self, txn, event): def _store_redaction(self, txn, event):
txn.execute( txn.execute(
"INSERT OR IGNORE INTO redactions " "INSERT OR IGNORE INTO redactions "
@ -480,6 +538,9 @@ class DataStore(RoomMemberStore, RoomStore,
the rejected reason string if we rejected the event, else maps to the rejected reason string if we rejected the event, else maps to
None. None.
""" """
if not event_ids:
return defer.succeed({})
def f(txn): def f(txn):
sql = ( sql = (
"SELECT e.event_id, reason FROM events as e " "SELECT e.event_id, reason FROM events as e "

View file

@ -29,7 +29,7 @@ class PusherStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_pushers_by_app_id_and_pushkey(self, app_id_and_pushkey): def get_pushers_by_app_id_and_pushkey(self, app_id_and_pushkey):
sql = ( sql = (
"SELECT id, user_name, kind, instance_handle, app_id," "SELECT id, user_name, kind, profile_tag, app_id,"
"app_display_name, device_display_name, pushkey, ts, data, " "app_display_name, device_display_name, pushkey, ts, data, "
"last_token, last_success, failing_since " "last_token, last_success, failing_since "
"FROM pushers " "FROM pushers "
@ -45,7 +45,7 @@ class PusherStore(SQLBaseStore):
"id": r[0], "id": r[0],
"user_name": r[1], "user_name": r[1],
"kind": r[2], "kind": r[2],
"instance_handle": r[3], "profile_tag": r[3],
"app_id": r[4], "app_id": r[4],
"app_display_name": r[5], "app_display_name": r[5],
"device_display_name": r[6], "device_display_name": r[6],
@ -64,7 +64,7 @@ class PusherStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_all_pushers(self): def get_all_pushers(self):
sql = ( sql = (
"SELECT id, user_name, kind, instance_handle, app_id," "SELECT id, user_name, kind, profile_tag, app_id,"
"app_display_name, device_display_name, pushkey, ts, data, " "app_display_name, device_display_name, pushkey, ts, data, "
"last_token, last_success, failing_since " "last_token, last_success, failing_since "
"FROM pushers" "FROM pushers"
@ -77,7 +77,7 @@ class PusherStore(SQLBaseStore):
"id": r[0], "id": r[0],
"user_name": r[1], "user_name": r[1],
"kind": r[2], "kind": r[2],
"instance_handle": r[3], "profile_tag": r[3],
"app_id": r[4], "app_id": r[4],
"app_display_name": r[5], "app_display_name": r[5],
"device_display_name": r[6], "device_display_name": r[6],
@ -94,7 +94,7 @@ class PusherStore(SQLBaseStore):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_pusher(self, user_name, instance_handle, kind, app_id, def add_pusher(self, user_name, profile_tag, kind, app_id,
app_display_name, device_display_name, app_display_name, device_display_name,
pushkey, pushkey_ts, lang, data): pushkey, pushkey_ts, lang, data):
try: try:
@ -107,7 +107,7 @@ class PusherStore(SQLBaseStore):
dict( dict(
user_name=user_name, user_name=user_name,
kind=kind, kind=kind,
instance_handle=instance_handle, profile_tag=profile_tag,
app_display_name=app_display_name, app_display_name=app_display_name,
device_display_name=device_display_name, device_display_name=device_display_name,
ts=pushkey_ts, ts=pushkey_ts,
@ -158,7 +158,7 @@ class PushersTable(Table):
"id", "id",
"user_name", "user_name",
"kind", "kind",
"instance_handle", "profile_tag",
"app_id", "app_id",
"app_display_name", "app_display_name",
"device_display_name", "device_display_name",

View file

@ -24,7 +24,7 @@ CREATE TABLE IF NOT EXISTS rejections(
CREATE TABLE IF NOT EXISTS pushers ( CREATE TABLE IF NOT EXISTS pushers (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL, user_name TEXT NOT NULL,
instance_handle varchar(32) NOT NULL, profile_tag varchar(32) NOT NULL,
kind varchar(8) NOT NULL, kind varchar(8) NOT NULL,
app_id varchar(64) NOT NULL, app_id varchar(64) NOT NULL,
app_display_name varchar(64) NOT NULL, app_display_name varchar(64) NOT NULL,

View file

@ -16,7 +16,7 @@
CREATE TABLE IF NOT EXISTS pushers ( CREATE TABLE IF NOT EXISTS pushers (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL, user_name TEXT NOT NULL,
instance_handle varchar(32) NOT NULL, profile_tag varchar(32) NOT NULL,
kind varchar(8) NOT NULL, kind varchar(8) NOT NULL,
app_id varchar(64) NOT NULL, app_id varchar(64) NOT NULL,
app_display_name varchar(64) NOT NULL, app_display_name varchar(64) NOT NULL,

View file

@ -91,7 +91,10 @@ class FederationTestCase(unittest.TestCase):
self.datastore.persist_event.return_value = defer.succeed(None) self.datastore.persist_event.return_value = defer.succeed(None)
self.datastore.get_room.return_value = defer.succeed(True) self.datastore.get_room.return_value = defer.succeed(True)
self.auth.check_host_in_room.return_value = defer.succeed(True) self.auth.check_host_in_room.return_value = defer.succeed(True)
self.datastore.have_events.return_value = defer.succeed({})
def have_events(event_ids):
return defer.succeed({})
self.datastore.have_events.side_effect = have_events
def annotate(ev, old_state=None): def annotate(ev, old_state=None):
context = Mock() context = Mock()