Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2017-01-20 14:00:31 +00:00
commit e41f183aa8
47 changed files with 1458 additions and 1002 deletions

View file

@ -138,6 +138,7 @@ Installing prerequisites on openSUSE::
python-devel libffi-devel libopenssl-devel libjpeg62-devel
Installing prerequisites on OpenBSD::
doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \
libxslt

View file

@ -16,18 +16,14 @@
import logging
import pymacaroons
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
from twisted.internet import defer
from unpaddedbase64 import decode_base64
import synapse.types
from synapse import event_auth
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import UserID, get_domain_from_id
from synapse.api.errors import AuthError, Codes
from synapse.types import UserID
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@ -78,147 +74,7 @@ class Auth(object):
True if the auth checks pass.
"""
with Measure(self.clock, "auth.check"):
self.check_size_limits(event)
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
if do_sig_check:
sender_domain = get_domain_from_id(event.sender)
event_id_domain = get_domain_from_id(event.event_id)
is_invite_via_3pid = (
event.type == EventTypes.Member
and event.membership == Membership.INVITE
and "third_party_invite" in event.content
)
# Check the sender's domain has signed the event
if not event.signatures.get(sender_domain):
# We allow invites via 3pid to have a sender from a different
# HS, as the sender must match the sender of the original
# 3pid invite. This is checked further down with the
# other dedicated membership checks.
if not is_invite_via_3pid:
raise AuthError(403, "Event not signed by sender's server")
# Check the event_id's domain has signed the event
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
logger.warn("Trusting event: %s", event.event_id)
return True
if event.type == EventTypes.Create:
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(
403,
"Creation event's room_id domain does not match sender's"
)
# FIXME
return True
creation_event = auth_events.get((EventTypes.Create, ""), None)
if not creation_event:
raise SynapseError(
403,
"Room %r does not exist" % (event.room_id,)
)
creating_domain = get_domain_from_id(event.room_id)
originating_domain = get_domain_from_id(event.sender)
if creating_domain != originating_domain:
if not self.can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# FIXME: Temp hack
if event.type == EventTypes.Aliases:
if not event.is_state():
raise AuthError(
403,
"Alias event must be a state event",
)
if not event.state_key:
raise AuthError(
403,
"Alias event must have non-empty state_key"
)
sender_domain = get_domain_from_id(event.sender)
if event.state_key != sender_domain:
raise AuthError(
403,
"Alias event's state_key does not match sender's domain"
)
return True
logger.debug(
"Auth events: %s",
[a.event_id for a in auth_events.values()]
)
if event.type == EventTypes.Member:
allowed = self.is_membership_change_allowed(
event, auth_events
)
if allowed:
logger.debug("Allowing! %s", event)
else:
logger.debug("Denying! %s", event)
return allowed
self.check_event_sender_in_room(event, auth_events)
# Special case to allow m.room.third_party_invite events wherever
# a user is allowed to issue invites. Fixes
# https://github.com/vector-im/vector-web/issues/1208 hopefully
if event.type == EventTypes.ThirdPartyInvite:
user_level = self._get_user_power_level(event.user_id, auth_events)
invite_level = self._get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, (
"You cannot issue a third party invite for %s." %
(event.content.display_name,)
)
)
else:
return True
self._can_send_event(event, auth_events)
if event.type == EventTypes.PowerLevels:
self._check_power_levels(event, auth_events)
if event.type == EventTypes.Redaction:
self.check_redaction(event, auth_events)
logger.debug("Allowing! %s", event)
def check_size_limits(self, event):
def too_big(field):
raise EventSizeError("%s too large" % (field,))
if len(event.user_id) > 255:
too_big("user_id")
if len(event.room_id) > 255:
too_big("room_id")
if event.is_state() and len(event.state_key) > 255:
too_big("state_key")
if len(event.type) > 255:
too_big("type")
if len(event.event_id) > 255:
too_big("event_id")
if len(encode_canonical_json(event.get_pdu_json())) > 65536:
too_big("event")
event_auth.check(event, auth_events, do_sig_check=do_sig_check)
@defer.inlineCallbacks
def check_joined_room(self, room_id, user_id, current_state=None):
@ -290,7 +146,7 @@ class Auth(object):
with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.info("calling resolve_state_groups from check_host_in_room")
logger.debug("calling resolve_state_groups from check_host_in_room")
entry = yield self.state.resolve_state_groups(
room_id, latest_event_ids
)
@ -300,16 +156,6 @@ class Auth(object):
)
defer.returnValue(ret)
def check_event_sender_in_room(self, event, auth_events):
key = (EventTypes.Member, event.user_id, )
member_event = auth_events.get(key)
return self._check_joined_room(
member_event,
event.user_id,
event.room_id
)
def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s (%s)" % (
@ -321,267 +167,8 @@ class Auth(object):
return creation_event.content.get("m.federate", True) is True
@log_function
def is_membership_change_allowed(self, event, auth_events):
membership = event.content["membership"]
# Check if this is the room creator joining:
if len(event.prev_events) == 1 and Membership.JOIN == membership:
# Get room creation event:
key = (EventTypes.Create, "", )
create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id:
if create.content["creator"] == event.state_key:
return True
target_user_id = event.state_key
creating_domain = get_domain_from_id(event.room_id)
target_domain = get_domain_from_id(target_user_id)
if creating_domain != target_domain:
if not self.can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# get info about the caller
key = (EventTypes.Member, event.user_id, )
caller = auth_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target
key = (EventTypes.Member, target_user_id, )
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
target_banned = target and target.membership == Membership.BAN
key = (EventTypes.JoinRules, "", )
join_rule_event = auth_events.get(key)
if join_rule_event:
join_rule = join_rule_event.content.get(
"join_rule", JoinRules.INVITE
)
else:
join_rule = JoinRules.INVITE
user_level = self._get_user_power_level(event.user_id, auth_events)
target_level = self._get_user_power_level(
target_user_id, auth_events
)
# FIXME (erikj): What should we do here as the default?
ban_level = self._get_named_level(auth_events, "ban", 50)
logger.debug(
"is_membership_change_allowed: %s",
{
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
"target_banned": target_banned,
"target_in_room": target_in_room,
"membership": membership,
"join_rule": join_rule,
"target_user_id": target_user_id,
"event.user_id": event.user_id,
}
)
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not self._verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
return True
if Membership.JOIN != membership:
if (caller_invited
and Membership.LEAVE == membership
and target_user_id == event.user_id):
return True
if not caller_in_room: # caller isn't joined
raise AuthError(
403,
"%s not in room %s." % (event.user_id, event.room_id,)
)
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
# Invites are valid iff caller is in the room and target isn't.
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." %
target_user_id)
else:
invite_level = self._get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, "You cannot invite user %s." % target_user_id
)
elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were:
# invited: They are accepting the invitation
# joined: It's a NOOP
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
elif target_banned:
raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC:
pass
elif join_rule == JoinRules.INVITE:
if not caller_in_room and not caller_invited:
raise AuthError(403, "You are not invited to this room.")
else:
# TODO (erikj): may_join list
# TODO (erikj): private rooms
raise AuthError(403, "You are not allowed to join this room")
elif Membership.LEAVE == membership:
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
raise AuthError(
403, "You cannot unban user &s." % (target_user_id,)
)
elif target_user_id != event.user_id:
kick_level = self._get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level:
raise AuthError(
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
else:
raise AuthError(500, "Unknown membership %s" % membership)
return True
def _verify_third_party_invite(self, event, auth_events):
"""
Validates that the invite event is authorized by a previous third-party invite.
Checks that the public key, and keyserver, match those in the third party invite,
and that the invite event has a signature issued using that public key.
Args:
event: The m.room.member join event being validated.
auth_events: All relevant previous context events which may be used
for authorization decisions.
Return:
True if the event fulfills the expectations of a previous third party
invite event.
"""
if "third_party_invite" not in event.content:
return False
if "signed" not in event.content["third_party_invite"]:
return False
signed = event.content["third_party_invite"]["signed"]
for key in {"mxid", "token"}:
if key not in signed:
return False
token = signed["token"]
invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,)
)
if not invite_event:
return False
if invite_event.sender != event.sender:
return False
if event.user_id != invite_event.user_id:
return False
if signed["mxid"] != event.state_key:
return False
if signed["token"] != token:
return False
for public_key_object in self.get_public_keys(invite_event):
public_key = public_key_object["public_key"]
try:
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"):
continue
verify_key = decode_verify_key_bytes(
key_name,
decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
# We got the public key from the invite, so we know that the
# correct server signed the signed bundle.
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True
except (KeyError, SignatureVerifyException,):
continue
return False
def get_public_keys(self, invite_event):
public_keys = []
if "public_key" in invite_event.content:
o = {
"public_key": invite_event.content["public_key"],
}
if "key_validity_url" in invite_event.content:
o["key_validity_url"] = invite_event.content["key_validity_url"]
public_keys.append(o)
public_keys.extend(invite_event.content.get("public_keys", []))
return public_keys
def _get_power_level_event(self, auth_events):
key = (EventTypes.PowerLevels, "", )
return auth_events.get(key)
def _get_user_power_level(self, user_id, auth_events):
power_level_event = self._get_power_level_event(auth_events)
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
if not level:
level = power_level_event.content.get("users_default", 0)
if level is None:
return 0
else:
return int(level)
else:
key = (EventTypes.Create, "", )
create_event = auth_events.get(key)
if (create_event is not None and
create_event.content["creator"] == user_id):
return 100
else:
return 0
def _get_named_level(self, auth_events, name, default):
power_level_event = self._get_power_level_event(auth_events)
if not power_level_event:
return default
level = power_level_event.content.get(name, None)
if level is not None:
return int(level)
else:
return default
return event_auth.get_public_keys(invite_event)
@defer.inlineCallbacks
def get_user_by_req(self, request, allow_guest=False, rights="access"):
@ -974,56 +561,6 @@ class Auth(object):
defer.returnValue(auth_ids)
def _get_send_level(self, etype, state_key, auth_events):
key = (EventTypes.PowerLevels, "", )
send_level_event = auth_events.get(key)
send_level = None
if send_level_event:
send_level = send_level_event.content.get("events", {}).get(
etype
)
if send_level is None:
if state_key is not None:
send_level = send_level_event.content.get(
"state_default", 50
)
else:
send_level = send_level_event.content.get(
"events_default", 0
)
if send_level:
send_level = int(send_level)
else:
send_level = 0
return send_level
@log_function
def _can_send_event(self, event, auth_events):
send_level = self._get_send_level(
event.type, event.get("state_key", None), auth_events
)
user_level = self._get_user_power_level(event.user_id, auth_events)
if user_level < send_level:
raise AuthError(
403,
"You don't have permission to post that to the room. " +
"user_level (%d) < send_level (%d)" % (user_level, send_level)
)
# Check state_key
if hasattr(event, "state_key"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
raise AuthError(
403,
"You are not allowed to set others state"
)
return True
def check_redaction(self, event, auth_events):
"""Check whether the event sender is allowed to redact the target event.
@ -1037,107 +574,7 @@ class Auth(object):
AuthError if the event sender is definitely not allowed to redact
the target event.
"""
user_level = self._get_user_power_level(event.user_id, auth_events)
redact_level = self._get_named_level(auth_events, "redact", 50)
if user_level >= redact_level:
return False
redacter_domain = get_domain_from_id(event.event_id)
redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain:
return True
raise AuthError(
403,
"You don't have permission to redact events"
)
def _check_power_levels(self, event, auth_events):
user_list = event.content.get("users", {})
# Validate users
for k, v in user_list.items():
try:
UserID.from_string(k)
except:
raise SynapseError(400, "Not a valid user_id: %s" % (k,))
try:
int(v)
except:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
key = (event.type, event.state_key, )
current_state = auth_events.get(key)
if not current_state:
return
user_level = self._get_user_power_level(event.user_id, auth_events)
# Check other levels:
levels_to_check = [
("users_default", None),
("events_default", None),
("state_default", None),
("ban", None),
("redact", None),
("kick", None),
("invite", None),
]
old_list = current_state.content.get("users")
for user in set(old_list.keys() + user_list.keys()):
levels_to_check.append(
(user, "users")
)
old_list = current_state.content.get("events")
new_list = event.content.get("events")
for ev_id in set(old_list.keys() + new_list.keys()):
levels_to_check.append(
(ev_id, "events")
)
old_state = current_state.content
new_state = event.content
for level_to_check, dir in levels_to_check:
old_loc = old_state
new_loc = new_state
if dir:
old_loc = old_loc.get(dir, {})
new_loc = new_loc.get(dir, {})
if level_to_check in old_loc:
old_level = int(old_loc[level_to_check])
else:
old_level = None
if level_to_check in new_loc:
new_level = int(new_loc[level_to_check])
else:
new_level = None
if new_level is not None and old_level is not None:
if new_level == old_level:
continue
if dir == "users" and level_to_check != event.user_id:
if old_level == user_level:
raise AuthError(
403,
"You don't have permission to remove ops level equal "
"to your own"
)
if old_level > user_level or new_level > user_level:
raise AuthError(
403,
"You don't have permission to add ops level greater "
"than your own"
)
return event_auth.check_redaction(event, auth_events)
@defer.inlineCallbacks
def check_can_change_room_list(self, room_id, user):
@ -1167,10 +604,10 @@ class Auth(object):
if power_level_event:
auth_events[(EventTypes.PowerLevels, "")] = power_level_event
send_level = self._get_send_level(
send_level = event_auth.get_send_level(
EventTypes.Aliases, "", auth_events
)
user_level = self._get_user_power_level(user_id, auth_events)
user_level = event_auth.get_user_power_level(user_id, auth_events)
if user_level < send_level:
raise AuthError(

View file

@ -76,8 +76,7 @@ class AppserviceServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", None)
bind_addresses = listener_config.get("bind_addresses", [])
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@ -87,9 +86,6 @@ class AppserviceServer(HomeServer):
root_resource = create_resource_tree(resources, Resource())
if bind_address is not None:
bind_addresses.append(bind_address)
for address in bind_addresses:
reactor.listenTCP(
port,
@ -109,11 +105,7 @@ class AppserviceServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
bind_address = listener.get("bind_address", None)
bind_addresses = listener.get("bind_addresses", [])
if bind_address is not None:
bind_addresses.append(bind_address)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(

View file

@ -90,8 +90,7 @@ class ClientReaderServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", None)
bind_addresses = listener_config.get("bind_addresses", [])
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@ -110,9 +109,6 @@ class ClientReaderServer(HomeServer):
root_resource = create_resource_tree(resources, Resource())
if bind_address is not None:
bind_addresses.append(bind_address)
for address in bind_addresses:
reactor.listenTCP(
port,
@ -132,11 +128,7 @@ class ClientReaderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
bind_address = listener.get("bind_address", None)
bind_addresses = listener.get("bind_addresses", [])
if bind_address is not None:
bind_addresses.append(bind_address)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(

View file

@ -86,8 +86,7 @@ class FederationReaderServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", None)
bind_addresses = listener_config.get("bind_addresses", [])
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@ -101,9 +100,6 @@ class FederationReaderServer(HomeServer):
root_resource = create_resource_tree(resources, Resource())
if bind_address is not None:
bind_addresses.append(bind_address)
for address in bind_addresses:
reactor.listenTCP(
port,
@ -123,11 +119,7 @@ class FederationReaderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
bind_address = listener.get("bind_address", None)
bind_addresses = listener.get("bind_addresses", [])
if bind_address is not None:
bind_addresses.append(bind_address)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(

View file

@ -82,8 +82,7 @@ class FederationSenderServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", None)
bind_addresses = listener_config.get("bind_addresses", [])
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@ -93,9 +92,6 @@ class FederationSenderServer(HomeServer):
root_resource = create_resource_tree(resources, Resource())
if bind_address is not None:
bind_addresses.append(bind_address)
for address in bind_addresses:
reactor.listenTCP(
port,
@ -115,11 +111,7 @@ class FederationSenderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
bind_address = listener.get("bind_address", None)
bind_addresses = listener.get("bind_addresses", [])
if bind_address is not None:
bind_addresses.append(bind_address)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(

View file

@ -107,8 +107,7 @@ def build_resource_for_web_client(hs):
class SynapseHomeServer(HomeServer):
def _listener_http(self, config, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", None)
bind_addresses = listener_config.get("bind_addresses", [])
bind_addresses = listener_config["bind_addresses"]
tls = listener_config.get("tls", False)
site_tag = listener_config.get("tag", port)
@ -175,9 +174,6 @@ class SynapseHomeServer(HomeServer):
root_resource = create_resource_tree(resources, root_resource)
if bind_address is not None:
bind_addresses.append(bind_address)
if tls:
for address in bind_addresses:
reactor.listenSSL(
@ -212,11 +208,7 @@ class SynapseHomeServer(HomeServer):
if listener["type"] == "http":
self._listener_http(config, listener)
elif listener["type"] == "manhole":
bind_address = listener.get("bind_address", None)
bind_addresses = listener.get("bind_addresses", [])
if bind_address is not None:
bind_addresses.append(bind_address)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(

View file

@ -87,8 +87,7 @@ class MediaRepositoryServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", None)
bind_addresses = listener_config.get("bind_addresses", [])
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@ -107,9 +106,6 @@ class MediaRepositoryServer(HomeServer):
root_resource = create_resource_tree(resources, Resource())
if bind_address is not None:
bind_addresses.append(bind_address)
for address in bind_addresses:
reactor.listenTCP(
port,
@ -129,11 +125,7 @@ class MediaRepositoryServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
bind_address = listener.get("bind_address", None)
bind_addresses = listener.get("bind_addresses", [])
if bind_address is not None:
bind_addresses.append(bind_address)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(

View file

@ -121,8 +121,7 @@ class PusherServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", None)
bind_addresses = listener_config.get("bind_addresses", [])
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@ -132,9 +131,6 @@ class PusherServer(HomeServer):
root_resource = create_resource_tree(resources, Resource())
if bind_address is not None:
bind_addresses.append(bind_address)
for address in bind_addresses:
reactor.listenTCP(
port,
@ -154,11 +150,7 @@ class PusherServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
bind_address = listener.get("bind_address", None)
bind_addresses = listener.get("bind_addresses", [])
if bind_address is not None:
bind_addresses.append(bind_address)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(

View file

@ -289,8 +289,7 @@ class SynchrotronServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", None)
bind_addresses = listener_config.get("bind_addresses", [])
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@ -312,9 +311,6 @@ class SynchrotronServer(HomeServer):
root_resource = create_resource_tree(resources, Resource())
if bind_address is not None:
bind_addresses.append(bind_address)
for address in bind_addresses:
reactor.listenTCP(
port,
@ -334,11 +330,7 @@ class SynchrotronServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
bind_address = listener.get("bind_address", None)
bind_addresses = listener.get("bind_addresses", [])
if bind_address is not None:
bind_addresses.append(bind_address)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(

View file

@ -68,6 +68,9 @@ class EmailConfig(Config):
self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True
)
self.email_riot_base_url = email_config.get(
"riot_base_url", None
)
if "app_name" in email_config:
self.email_app_name = email_config["app_name"]
else:
@ -85,6 +88,9 @@ class EmailConfig(Config):
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable sending emails for notification events
# Defining a custom URL for Riot is only needed if email notifications
# should contain links to a self-hosted installation of Riot; when set
# the "app_name" setting is ignored.
#email:
# enable_notifs: false
# smtp_host: "localhost"
@ -95,4 +101,5 @@ class EmailConfig(Config):
# notif_template_html: notif_mail.html
# notif_template_text: notif_mail.txt
# notif_for_new_users: True
# riot_base_url: "http://localhost/riot"
"""

View file

@ -22,7 +22,6 @@ import yaml
from string import Template
import os
import signal
from synapse.util.debug import debug_deferreds
DEFAULT_LOG_CONFIG = Template("""
@ -71,8 +70,6 @@ class LoggingConfig(Config):
self.verbosity = config.get("verbose", 0)
self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file"))
if config.get("full_twisted_stacktraces"):
debug_deferreds()
def default_config(self, config_dir_path, server_name, **kwargs):
log_file = self.abspath("homeserver.log")
@ -88,11 +85,6 @@ class LoggingConfig(Config):
# A yaml python logging config file
log_config: "%(log_config)s"
# Stop twisted from discarding the stack traces of exceptions in
# deferreds by waiting a reactor tick before running a deferred's
# callbacks.
# full_twisted_stacktraces: true
""" % locals()
def read_arguments(self, args):

View file

@ -42,6 +42,15 @@ class ServerConfig(Config):
self.listeners = config.get("listeners", [])
for listener in self.listeners:
bind_address = listener.pop("bind_address", None)
bind_addresses = listener.setdefault("bind_addresses", [])
if bind_address:
bind_addresses.append(bind_address)
elif not bind_addresses:
bind_addresses.append('')
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
bind_port = config.get("bind_port")
@ -54,7 +63,7 @@ class ServerConfig(Config):
self.listeners.append({
"port": bind_port,
"bind_address": bind_host,
"bind_addresses": [bind_host],
"tls": True,
"type": "http",
"resources": [
@ -73,7 +82,7 @@ class ServerConfig(Config):
if unsecure_port:
self.listeners.append({
"port": unsecure_port,
"bind_address": bind_host,
"bind_addresses": [bind_host],
"tls": False,
"type": "http",
"resources": [
@ -92,7 +101,7 @@ class ServerConfig(Config):
if manhole:
self.listeners.append({
"port": manhole,
"bind_address": "127.0.0.1",
"bind_addresses": ["127.0.0.1"],
"type": "manhole",
})
@ -100,7 +109,7 @@ class ServerConfig(Config):
if metrics_port:
self.listeners.append({
"port": metrics_port,
"bind_address": config.get("metrics_bind_host", "127.0.0.1"),
"bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
"tls": False,
"type": "http",
"resources": [

View file

@ -19,7 +19,9 @@ class VoipConfig(Config):
def read_config(self, config):
self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config["turn_shared_secret"]
self.turn_shared_secret = config.get("turn_shared_secret")
self.turn_username = config.get("turn_username")
self.turn_password = config.get("turn_password")
self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
def default_config(self, **kwargs):
@ -32,6 +34,11 @@ class VoipConfig(Config):
# The shared secret used to compute passwords for the TURN server
turn_shared_secret: "YOUR_SHARED_SECRET"
# The Username and password if the TURN server needs them and
# does not use a token
#turn_username: "TURNSERVER_USERNAME"
#turn_password: "TURNSERVER_PASSWORD"
# How long generated TURN credentials last
turn_user_lifetime: "1h"
"""

View file

@ -29,3 +29,13 @@ class WorkerConfig(Config):
self.worker_log_file = config.get("worker_log_file")
self.worker_log_config = config.get("worker_log_config")
self.worker_replication_url = config.get("worker_replication_url")
if self.worker_listeners:
for listener in self.worker_listeners:
bind_address = listener.pop("bind_address", None)
bind_addresses = listener.setdefault("bind_addresses", [])
if bind_address:
bind_addresses.append(bind_address)
elif not bind_addresses:
bind_addresses.append('')

678
synapse/event_auth.py Normal file
View file

@ -0,0 +1,678 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
from unpaddedbase64 import decode_base64
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, SynapseError, EventSizeError
from synapse.types import UserID, get_domain_from_id
logger = logging.getLogger(__name__)
def check(event, auth_events, do_sig_check=True, do_size_check=True):
""" Checks if this event is correctly authed.
Args:
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
Returns:
True if the auth checks pass.
"""
if do_size_check:
_check_size_limits(event)
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
if do_sig_check:
sender_domain = get_domain_from_id(event.sender)
event_id_domain = get_domain_from_id(event.event_id)
is_invite_via_3pid = (
event.type == EventTypes.Member
and event.membership == Membership.INVITE
and "third_party_invite" in event.content
)
# Check the sender's domain has signed the event
if not event.signatures.get(sender_domain):
# We allow invites via 3pid to have a sender from a different
# HS, as the sender must match the sender of the original
# 3pid invite. This is checked further down with the
# other dedicated membership checks.
if not is_invite_via_3pid:
raise AuthError(403, "Event not signed by sender's server")
# Check the event_id's domain has signed the event
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
logger.warn("Trusting event: %s", event.event_id)
return True
if event.type == EventTypes.Create:
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(
403,
"Creation event's room_id domain does not match sender's"
)
# FIXME
return True
creation_event = auth_events.get((EventTypes.Create, ""), None)
if not creation_event:
raise SynapseError(
403,
"Room %r does not exist" % (event.room_id,)
)
creating_domain = get_domain_from_id(event.room_id)
originating_domain = get_domain_from_id(event.sender)
if creating_domain != originating_domain:
if not _can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# FIXME: Temp hack
if event.type == EventTypes.Aliases:
if not event.is_state():
raise AuthError(
403,
"Alias event must be a state event",
)
if not event.state_key:
raise AuthError(
403,
"Alias event must have non-empty state_key"
)
sender_domain = get_domain_from_id(event.sender)
if event.state_key != sender_domain:
raise AuthError(
403,
"Alias event's state_key does not match sender's domain"
)
return True
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Auth events: %s",
[a.event_id for a in auth_events.values()]
)
if event.type == EventTypes.Member:
allowed = _is_membership_change_allowed(
event, auth_events
)
if allowed:
logger.debug("Allowing! %s", event)
else:
logger.debug("Denying! %s", event)
return allowed
_check_event_sender_in_room(event, auth_events)
# Special case to allow m.room.third_party_invite events wherever
# a user is allowed to issue invites. Fixes
# https://github.com/vector-im/vector-web/issues/1208 hopefully
if event.type == EventTypes.ThirdPartyInvite:
user_level = get_user_power_level(event.user_id, auth_events)
invite_level = _get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, (
"You cannot issue a third party invite for %s." %
(event.content.display_name,)
)
)
else:
return True
_can_send_event(event, auth_events)
if event.type == EventTypes.PowerLevels:
_check_power_levels(event, auth_events)
if event.type == EventTypes.Redaction:
check_redaction(event, auth_events)
logger.debug("Allowing! %s", event)
def _check_size_limits(event):
def too_big(field):
raise EventSizeError("%s too large" % (field,))
if len(event.user_id) > 255:
too_big("user_id")
if len(event.room_id) > 255:
too_big("room_id")
if event.is_state() and len(event.state_key) > 255:
too_big("state_key")
if len(event.type) > 255:
too_big("type")
if len(event.event_id) > 255:
too_big("event_id")
if len(encode_canonical_json(event.get_pdu_json())) > 65536:
too_big("event")
def _can_federate(event, auth_events):
creation_event = auth_events.get((EventTypes.Create, ""))
return creation_event.content.get("m.federate", True) is True
def _is_membership_change_allowed(event, auth_events):
membership = event.content["membership"]
# Check if this is the room creator joining:
if len(event.prev_events) == 1 and Membership.JOIN == membership:
# Get room creation event:
key = (EventTypes.Create, "", )
create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id:
if create.content["creator"] == event.state_key:
return True
target_user_id = event.state_key
creating_domain = get_domain_from_id(event.room_id)
target_domain = get_domain_from_id(target_user_id)
if creating_domain != target_domain:
if not _can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# get info about the caller
key = (EventTypes.Member, event.user_id, )
caller = auth_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target
key = (EventTypes.Member, target_user_id, )
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
target_banned = target and target.membership == Membership.BAN
key = (EventTypes.JoinRules, "", )
join_rule_event = auth_events.get(key)
if join_rule_event:
join_rule = join_rule_event.content.get(
"join_rule", JoinRules.INVITE
)
else:
join_rule = JoinRules.INVITE
user_level = get_user_power_level(event.user_id, auth_events)
target_level = get_user_power_level(
target_user_id, auth_events
)
# FIXME (erikj): What should we do here as the default?
ban_level = _get_named_level(auth_events, "ban", 50)
logger.debug(
"_is_membership_change_allowed: %s",
{
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
"target_banned": target_banned,
"target_in_room": target_in_room,
"membership": membership,
"join_rule": join_rule,
"target_user_id": target_user_id,
"event.user_id": event.user_id,
}
)
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not _verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
return True
if Membership.JOIN != membership:
if (caller_invited
and Membership.LEAVE == membership
and target_user_id == event.user_id):
return True
if not caller_in_room: # caller isn't joined
raise AuthError(
403,
"%s not in room %s." % (event.user_id, event.room_id,)
)
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
# Invites are valid iff caller is in the room and target isn't.
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." %
target_user_id)
else:
invite_level = _get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, "You cannot invite user %s." % target_user_id
)
elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were:
# invited: They are accepting the invitation
# joined: It's a NOOP
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
elif target_banned:
raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC:
pass
elif join_rule == JoinRules.INVITE:
if not caller_in_room and not caller_invited:
raise AuthError(403, "You are not invited to this room.")
else:
# TODO (erikj): may_join list
# TODO (erikj): private rooms
raise AuthError(403, "You are not allowed to join this room")
elif Membership.LEAVE == membership:
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
raise AuthError(
403, "You cannot unban user &s." % (target_user_id,)
)
elif target_user_id != event.user_id:
kick_level = _get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level:
raise AuthError(
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
else:
raise AuthError(500, "Unknown membership %s" % membership)
return True
def _check_event_sender_in_room(event, auth_events):
key = (EventTypes.Member, event.user_id, )
member_event = auth_events.get(key)
return _check_joined_room(
member_event,
event.user_id,
event.room_id
)
def _check_joined_room(member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s (%s)" % (
user_id, room_id, repr(member)
))
def get_send_level(etype, state_key, auth_events):
key = (EventTypes.PowerLevels, "", )
send_level_event = auth_events.get(key)
send_level = None
if send_level_event:
send_level = send_level_event.content.get("events", {}).get(
etype
)
if send_level is None:
if state_key is not None:
send_level = send_level_event.content.get(
"state_default", 50
)
else:
send_level = send_level_event.content.get(
"events_default", 0
)
if send_level:
send_level = int(send_level)
else:
send_level = 0
return send_level
def _can_send_event(event, auth_events):
send_level = get_send_level(
event.type, event.get("state_key", None), auth_events
)
user_level = get_user_power_level(event.user_id, auth_events)
if user_level < send_level:
raise AuthError(
403,
"You don't have permission to post that to the room. " +
"user_level (%d) < send_level (%d)" % (user_level, send_level)
)
# Check state_key
if hasattr(event, "state_key"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
raise AuthError(
403,
"You are not allowed to set others state"
)
return True
def check_redaction(event, auth_events):
"""Check whether the event sender is allowed to redact the target event.
Returns:
True if the the sender is allowed to redact the target event if the
target event was created by them.
False if the sender is allowed to redact the target event with no
further checks.
Raises:
AuthError if the event sender is definitely not allowed to redact
the target event.
"""
user_level = get_user_power_level(event.user_id, auth_events)
redact_level = _get_named_level(auth_events, "redact", 50)
if user_level >= redact_level:
return False
redacter_domain = get_domain_from_id(event.event_id)
redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain:
return True
raise AuthError(
403,
"You don't have permission to redact events"
)
def _check_power_levels(event, auth_events):
user_list = event.content.get("users", {})
# Validate users
for k, v in user_list.items():
try:
UserID.from_string(k)
except:
raise SynapseError(400, "Not a valid user_id: %s" % (k,))
try:
int(v)
except:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
key = (event.type, event.state_key, )
current_state = auth_events.get(key)
if not current_state:
return
user_level = get_user_power_level(event.user_id, auth_events)
# Check other levels:
levels_to_check = [
("users_default", None),
("events_default", None),
("state_default", None),
("ban", None),
("redact", None),
("kick", None),
("invite", None),
]
old_list = current_state.content.get("users")
for user in set(old_list.keys() + user_list.keys()):
levels_to_check.append(
(user, "users")
)
old_list = current_state.content.get("events")
new_list = event.content.get("events")
for ev_id in set(old_list.keys() + new_list.keys()):
levels_to_check.append(
(ev_id, "events")
)
old_state = current_state.content
new_state = event.content
for level_to_check, dir in levels_to_check:
old_loc = old_state
new_loc = new_state
if dir:
old_loc = old_loc.get(dir, {})
new_loc = new_loc.get(dir, {})
if level_to_check in old_loc:
old_level = int(old_loc[level_to_check])
else:
old_level = None
if level_to_check in new_loc:
new_level = int(new_loc[level_to_check])
else:
new_level = None
if new_level is not None and old_level is not None:
if new_level == old_level:
continue
if dir == "users" and level_to_check != event.user_id:
if old_level == user_level:
raise AuthError(
403,
"You don't have permission to remove ops level equal "
"to your own"
)
if old_level > user_level or new_level > user_level:
raise AuthError(
403,
"You don't have permission to add ops level greater "
"than your own"
)
def _get_power_level_event(auth_events):
key = (EventTypes.PowerLevels, "", )
return auth_events.get(key)
def get_user_power_level(user_id, auth_events):
power_level_event = _get_power_level_event(auth_events)
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
if not level:
level = power_level_event.content.get("users_default", 0)
if level is None:
return 0
else:
return int(level)
else:
key = (EventTypes.Create, "", )
create_event = auth_events.get(key)
if (create_event is not None and
create_event.content["creator"] == user_id):
return 100
else:
return 0
def _get_named_level(auth_events, name, default):
power_level_event = _get_power_level_event(auth_events)
if not power_level_event:
return default
level = power_level_event.content.get(name, None)
if level is not None:
return int(level)
else:
return default
def _verify_third_party_invite(event, auth_events):
"""
Validates that the invite event is authorized by a previous third-party invite.
Checks that the public key, and keyserver, match those in the third party invite,
and that the invite event has a signature issued using that public key.
Args:
event: The m.room.member join event being validated.
auth_events: All relevant previous context events which may be used
for authorization decisions.
Return:
True if the event fulfills the expectations of a previous third party
invite event.
"""
if "third_party_invite" not in event.content:
return False
if "signed" not in event.content["third_party_invite"]:
return False
signed = event.content["third_party_invite"]["signed"]
for key in {"mxid", "token"}:
if key not in signed:
return False
token = signed["token"]
invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,)
)
if not invite_event:
return False
if invite_event.sender != event.sender:
return False
if event.user_id != invite_event.user_id:
return False
if signed["mxid"] != event.state_key:
return False
if signed["token"] != token:
return False
for public_key_object in get_public_keys(invite_event):
public_key = public_key_object["public_key"]
try:
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"):
continue
verify_key = decode_verify_key_bytes(
key_name,
decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
# We got the public key from the invite, so we know that the
# correct server signed the signed bundle.
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True
except (KeyError, SignatureVerifyException,):
continue
return False
def get_public_keys(invite_event):
public_keys = []
if "public_key" in invite_event.content:
o = {
"public_key": invite_event.content["public_key"],
}
if "key_validity_url" in invite_event.content:
o["key_validity_url"] = invite_event.content["key_validity_url"]
public_keys.append(o)
public_keys.extend(invite_event.content.get("public_keys", []))
return public_keys
def auth_types_for_event(event):
"""Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what
would actually be required depending on the full state of the room.
Used to limit the number of events to fetch from the database to
actually auth the event.
"""
if event.type == EventTypes.Create:
return []
auth_types = []
auth_types.append((EventTypes.PowerLevels, "", ))
auth_types.append((EventTypes.Member, event.user_id, ))
auth_types.append((EventTypes.Create, "", ))
if event.type == EventTypes.Member:
membership = event.content["membership"]
if membership in [Membership.JOIN, Membership.INVITE]:
auth_types.append((EventTypes.JoinRules, "", ))
auth_types.append((EventTypes.Member, event.state_key, ))
if membership == Membership.INVITE:
if "third_party_invite" in event.content:
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"]
)
auth_types.append(key)
return auth_types

View file

@ -79,7 +79,6 @@ class EventBase(object):
auth_events = _event_dict_property("auth_events")
depth = _event_dict_property("depth")
content = _event_dict_property("content")
event_id = _event_dict_property("event_id")
hashes = _event_dict_property("hashes")
origin = _event_dict_property("origin")
origin_server_ts = _event_dict_property("origin_server_ts")
@ -88,8 +87,6 @@ class EventBase(object):
redacts = _event_dict_property("redacts")
room_id = _event_dict_property("room_id")
sender = _event_dict_property("sender")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
user_id = _event_dict_property("sender")
@property
@ -162,6 +159,11 @@ class FrozenEvent(EventBase):
else:
frozen_dict = event_dict
self.event_id = event_dict["event_id"]
self.type = event_dict["type"]
if "state_key" in event_dict:
self.state_key = event_dict["state_key"]
super(FrozenEvent, self).__init__(
frozen_dict,
signatures=signatures,

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import EventBase, FrozenEvent
from . import EventBase, FrozenEvent, _event_dict_property
from synapse.types import EventID
@ -34,6 +34,10 @@ class EventBuilder(EventBase):
internal_metadata_dict=internal_metadata_dict,
)
event_id = _event_dict_property("event_id")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
def build(self):
return FrozenEvent.from_event(self)

View file

@ -26,7 +26,7 @@ from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent
from synapse.events import FrozenEvent, builder
import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@ -499,8 +499,10 @@ class FederationClient(FederationBase):
if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = []
ev = builder.EventBuilder(pdu_dict)
defer.returnValue(
(destination, self.event_from_pdu_json(pdu_dict))
(destination, ev)
)
break
except CodeMessageException as e:

View file

@ -362,7 +362,7 @@ class TransactionQueue(object):
if not success:
break
except NotRetryingDestination:
logger.info(
logger.debug(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,

View file

@ -591,12 +591,12 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys())
logger.info("calling resolve_state_groups in _maybe_backfill")
logger.debug("calling resolve_state_groups in _maybe_backfill")
states = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
for e in event_ids
]))
states = dict(zip(event_ids, [s[1] for s in states]))
states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events(
[e_id for ids in states.values() for e_id in ids],
@ -1530,7 +1530,7 @@ class FederationHandler(BaseHandler):
(d.type, d.state_key): d for d in different_events if d
})
new_state, prev_state = self.state_handler.resolve_events(
new_state = self.state_handler.resolve_events(
[local_view.values(), remote_view.values()],
event
)

View file

@ -439,15 +439,23 @@ class Mailer(object):
})
def make_room_link(self, room_id):
# need /beta for Universal Links to work on iOS
if self.app_name == "Vector":
return "https://vector.im/beta/#/room/%s" % (room_id,)
if self.hs.config.email_riot_base_url:
base_url = self.hs.config.email_riot_base_url
elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS
base_url = "https://vector.im/beta/#/room"
else:
return "https://matrix.to/#/%s" % (room_id,)
base_url = "https://matrix.to/#"
return "%s/%s" % (base_url, room_id)
def make_notif_link(self, notif):
# need /beta for Universal Links to work on iOS
if self.app_name == "Vector":
if self.hs.config.email_riot_base_url:
return "%s/#/room/%s/%s" % (
self.hs.config.email_riot_base_url,
notif['room_id'], notif['event_id']
)
elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS
return "https://vector.im/beta/#/room/%s/%s" % (
notif['room_id'], notif['event_id']
)

View file

@ -52,7 +52,7 @@ def get_badge_count(store, user_id):
def get_context_for_event(store, state_handler, ev, user_id):
ctx = {}
room_state_ids = yield state_handler.get_current_state_ids(ev.room_id)
room_state_ids = yield store.get_state_ids_for_event(ev.event_id)
# we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.name, an alias or

View file

@ -86,7 +86,11 @@ class HttpTransactionCache(object):
pass # execute the function instead.
deferred = fn(*args, **kwargs)
observable = ObservableDeferred(deferred)
# We don't add an errback to the raw deferred, so we ask ObservableDeferred
# to swallow the error. This is fine as the error will still be reported
# to the observers.
observable = ObservableDeferred(deferred, consumeErrors=True)
self.transactions[txn_key] = (observable, self.clock.time_msec())
return observable.observe()

View file

@ -118,8 +118,14 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def do_password_login(self, login_submission):
if 'medium' in login_submission and 'address' in login_submission:
address = login_submission['address']
if login_submission['medium'] == 'email':
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
login_submission['medium'], login_submission['address']
login_submission['medium'], address
)
if not user_id:
raise LoginError(403, "", errcode=Codes.FORBIDDEN)

View file

@ -32,19 +32,27 @@ class VoipRestServlet(ClientV1RestServlet):
turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret
turnUsername = self.hs.config.turn_username
turnPassword = self.hs.config.turn_password
userLifetime = self.hs.config.turn_user_lifetime
if not turnUris or not turnSecret or not userLifetime:
if turnUris and turnSecret and userLifetime:
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
username = "%d:%s" % (expiry, requester.user.to_string())
mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
# We need to use standard padded base64 encoding here
# encode_base64 because we need to add the standard padding to get the
# same result as the TURN server.
password = base64.b64encode(mac.digest())
elif turnUris and turnUsername and turnPassword and userLifetime:
username = turnUsername
password = turnPassword
else:
defer.returnValue((200, {}))
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
username = "%d:%s" % (expiry, requester.user.to_string())
mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
# We need to use standard padded base64 encoding here
# encode_base64 because we need to add the standard padding to get the
# same result as the TURN server.
password = base64.b64encode(mac.digest())
defer.returnValue((200, {
'username': username,
'password': password,

View file

@ -96,6 +96,11 @@ class PasswordRestServlet(RestServlet):
threepid = result[LoginType.EMAIL_IDENTITY]
if 'medium' not in threepid or 'address' not in threepid:
raise SynapseError(500, "Malformed threepid")
if threepid['medium'] == 'email':
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
threepid['address'] = threepid['address'].lower()
# if using email, we must know about the email they're authing with!
threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
threepid['medium'], threepid['address']

View file

@ -16,12 +16,12 @@
from twisted.internet import defer
from synapse import event_auth
from synapse.util.logutils import log_function
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer
@ -41,12 +41,14 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR)
EVICTION_TIMEOUT_SECONDS = 60 * 60
_NEXT_STATE_ID = 1
POWER_KEY = (EventTypes.PowerLevels, "")
def _gen_state_id():
global _NEXT_STATE_ID
@ -77,6 +79,9 @@ class _StateCacheEntry(object):
else:
self.state_id = _gen_state_id()
def __len__(self):
return len(self.state)
class StateHandler(object):
""" Responsible for doing state conflict resolution.
@ -99,6 +104,7 @@ class StateHandler(object):
clock=self.clock,
max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
)
@ -123,7 +129,7 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.info("calling resolve_state_groups from get_current_state")
logger.debug("calling resolve_state_groups from get_current_state")
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
@ -148,7 +154,7 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.info("calling resolve_state_groups from get_current_state_ids")
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
@ -162,7 +168,7 @@ class StateHandler(object):
def get_current_user_in_room(self, room_id, latest_event_ids=None):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.info("calling resolve_state_groups from get_current_user_in_room")
logger.debug("calling resolve_state_groups from get_current_user_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(
room_id, entry.state_id, entry.state
@ -226,7 +232,7 @@ class StateHandler(object):
context.prev_state_events = []
defer.returnValue(context)
logger.info("calling resolve_state_groups from compute_event_context")
logger.debug("calling resolve_state_groups from compute_event_context")
if event.is_state():
entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events],
@ -327,20 +333,13 @@ class StateHandler(object):
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
state_map = yield self.store.get_events(
[e_id for st in state_groups_ids.values() for e_id in st.values()],
get_prev_content=False
)
state_sets = [
[state_map[e_id] for key, e_id in st.items() if e_id in state_map]
for st in state_groups_ids.values()
]
new_state, _ = self._resolve_events(
state_sets, event_type, state_key
)
new_state = {
key: e.event_id for key, e in new_state.items()
}
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events(
state_groups_ids.values(),
state_map_factory=lambda ev_ids: self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
),
)
else:
new_state = {
key: e_ids.pop() for key, e_ids in state.items()
@ -388,152 +387,264 @@ class StateHandler(object):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
if event.is_state():
return self._resolve_events(
state_sets, event.type, event.state_key
)
else:
return self._resolve_events(state_sets)
state_set_ids = [{
(ev.type, ev.state_key): ev.event_id
for ev in st
} for st in state_sets]
state_map = {
ev.event_id: ev
for st in state_sets
for ev in st
}
def _resolve_events(self, state_sets, event_type=None, state_key=""):
"""
Returns
(dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple
(new_state, prev_states). new_state is a map from (type, state_key)
to event. prev_states is a list of event_ids.
"""
with Measure(self.clock, "state._resolve_events"):
state = {}
for st in state_sets:
for e in st:
state.setdefault(
(e.type, e.state_key),
{}
)[e.event_id] = e
new_state = resolve_events(state_set_ids, state_map)
unconflicted_state = {
k: v.values()[0] for k, v in state.items()
if len(v.values()) == 1
}
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
}
conflicted_state = {
k: v.values()
for k, v in state.items()
if len(v.values()) > 1
}
return new_state
if event_type:
prev_states_events = conflicted_state.get(
(event_type, state_key), []
)
prev_states = [s.event_id for s in prev_states_events]
else:
prev_states = []
auth_events = {
k: e for k, e in unconflicted_state.items()
if k[0] in AuthEventTypes
}
def _ordered_events(events):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
try:
resolved_state = self._resolve_state_events(
conflicted_state, auth_events
)
except:
logger.exception("Failed to resolve state")
raise
return sorted(events, key=key_func)
new_state = unconflicted_state
new_state.update(resolved_state)
return new_state, prev_states
def resolve_events(state_sets, state_map_factory):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map_factory(dict|callable): If callable, then will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event. Otherwise, should be
a dict from event_id to event of all events in state_sets.
@log_function
def _resolve_state_events(self, conflicted_state, auth_events):
""" This is where we actually decide which of the conflicted state to
use.
Returns
dict[(str, str), synapse.events.FrozenEvent] is a map from
(type, state_key) to event.
"""
unconflicted_state, conflicted_state = _seperate(
state_sets,
)
We resolve conflicts in the following order:
1. power levels
2. join rules
3. memberships
4. other events.
"""
resolved_state = {}
power_key = (EventTypes.PowerLevels, "")
if power_key in conflicted_state:
events = conflicted_state[power_key]
logger.debug("Resolving conflicted power levels %r", events)
resolved_state[power_key] = self._resolve_auth_events(
events, auth_events)
if callable(state_map_factory):
return _resolve_with_state_fac(
unconflicted_state, conflicted_state, state_map_factory
)
auth_events.update(resolved_state)
state_map = state_map_factory
for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = self._resolve_auth_events(
events,
auth_events
)
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
auth_events.update(resolved_state)
return _resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
)
for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = self._resolve_auth_events(
events,
auth_events
)
auth_events.update(resolved_state)
def _seperate(state_sets):
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
"""
unconflicted_state = dict(state_sets[0])
conflicted_state = {}
for key, events in conflicted_state.items():
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = self._resolve_normal_events(
events, auth_events
)
for state_set in state_sets[1:]:
for key, value in state_set.iteritems():
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
# There isn't an unconflicted entry so check if there is a
# conflicted entry.
ls = conflicted_state.get(key)
if ls is None:
# There wasn't a conflicted entry so haven't seen this key before.
# Therefore it isn't conflicted yet.
unconflicted_state[key] = value
else:
# This key is already conflicted, add our value to the conflict set.
ls.add(value)
elif unconflicted_value != value:
# If the unconflicted value is not the same as our value then we
# have a new conflict. So move the key from the unconflicted_state
# to the conflicted state.
conflicted_state[key] = {value, unconflicted_value}
unconflicted_state.pop(key, None)
return resolved_state
return unconflicted_state, conflicted_state
def _resolve_auth_events(self, events, auth_events):
reverse = [i for i in reversed(self._ordered_events(events))]
auth_events = dict(auth_events)
@defer.inlineCallbacks
def _resolve_with_state_fac(unconflicted_state, conflicted_state,
state_map_factory):
needed_events = set(
event_id
for event_ids in conflicted_state.itervalues()
for event_id in event_ids
)
prev_event = reverse[0]
for event in reverse[1:]:
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
# The signatures have already been checked at this point
self.hs.get_auth().check(event, auth_events, do_sig_check=False)
prev_event = event
except AuthError:
return prev_event
logger.info("Asking for %d conflicted events", len(needed_events))
return event
state_map = yield state_map_factory(needed_events)
def _resolve_normal_events(self, events, auth_events):
for event in self._ordered_events(events):
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
# The signatures have already been checked at this point
self.hs.get_auth().check(event, auth_events, do_sig_check=False)
return event
except AuthError:
pass
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
# Use the last event (the one with the least depth) if they all fail
# the auth check.
return event
new_needed_events = set(auth_events.itervalues())
new_needed_events -= needed_events
def _ordered_events(self, events):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
logger.info("Asking for %d auth events", len(new_needed_events))
return sorted(events, key=key_func)
state_map_new = yield state_map_factory(new_needed_events)
state_map.update(state_map_new)
defer.returnValue(_resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
))
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
for event_ids in conflicted_state.itervalues():
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
for key in keys:
if key not in auth_events:
event_id = unconflicted_state.get(key, None)
if event_id:
auth_events[key] = event_id
return auth_events
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
state_map):
conflicted_state = {}
for key, event_ids in conflicted_state_ds.iteritems():
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
elif len(events) == 1:
unconflicted_state_ids[key] = events[0].event_id
auth_events = {
key: state_map[ev_id]
for key, ev_id in auth_event_ids.items()
if ev_id in state_map
}
try:
resolved_state = _resolve_state_events(
conflicted_state, auth_events
)
except:
logger.exception("Failed to resolve state")
raise
new_state = unconflicted_state_ids
for key, event in resolved_state.iteritems():
new_state[key] = event.event_id
return new_state
def _resolve_state_events(conflicted_state, auth_events):
""" This is where we actually decide which of the conflicted state to
use.
We resolve conflicts in the following order:
1. power levels
2. join rules
3. memberships
4. other events.
"""
resolved_state = {}
if POWER_KEY in conflicted_state:
events = conflicted_state[POWER_KEY]
logger.debug("Resolving conflicted power levels %r", events)
resolved_state[POWER_KEY] = _resolve_auth_events(
events, auth_events)
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(
events,
auth_events
)
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(
events,
auth_events
)
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events(
events, auth_events
)
return resolved_state
def _resolve_auth_events(events, auth_events):
reverse = [i for i in reversed(_ordered_events(events))]
auth_keys = set(
key
for event in events
for key in event_auth.auth_types_for_event(event)
)
new_auth_events = {}
for key in auth_keys:
auth_event = auth_events.get(key, None)
if auth_event:
new_auth_events[key] = auth_event
auth_events = new_auth_events
prev_event = reverse[0]
for event in reverse[1:]:
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try:
# The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
prev_event = event
except AuthError:
return prev_event
return event
def _resolve_normal_events(events, auth_events):
for event in _ordered_events(events):
try:
# The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
return event
except AuthError:
pass
# Use the last event (the one with the least depth) if they all fail
# the auth check.
return event

View file

@ -169,7 +169,7 @@ class SQLBaseStore(object):
max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
)
self._event_fetch_lock = threading.Condition()

View file

@ -18,13 +18,29 @@ import ujson
from twisted.internet import defer
from ._base import SQLBaseStore
from .background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__)
class DeviceInboxStore(SQLBaseStore):
class DeviceInboxStore(BackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, hs):
super(DeviceInboxStore, self).__init__(hs)
self.register_background_index_update(
"device_inbox_stream_index",
index_name="device_inbox_stream_id_user_id",
table="device_inbox",
columns=["stream_id", "user_id"],
)
self.register_background_update_handler(
self.DEVICE_INBOX_STREAM_ID,
self._background_drop_index_device_inbox,
)
@defer.inlineCallbacks
def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
@ -368,3 +384,18 @@ class DeviceInboxStore(SQLBaseStore):
"delete_device_msgs_for_remote",
delete_messages_for_remote_destination_txn
)
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute(
"DROP INDEX IF EXISTS device_inbox_stream_id"
)
txn.close()
yield self.runWithConnection(reindex_txn)
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
defer.returnValue(1)

View file

@ -1090,10 +1090,10 @@ class EventsStore(SQLBaseStore):
self._do_fetch
)
logger.info("Loading %d events", len(events))
logger.debug("Loading %d events", len(events))
with PreserveLoggingContext():
rows = yield events_d
logger.info("Loaded %d events (%d rows)", len(events), len(rows))
logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 39
SCHEMA_VERSION = 40
dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -390,7 +390,8 @@ class RoomMemberStore(SQLBaseStore):
room_id, state_group, state_ids,
)
@cachedInlineCallbacks(num_args=2, cache_context=True)
@cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
max_entries=100000)
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
cache_context, event=None):
# We don't use `state_group`, it's there so that we can cache based

View file

@ -0,0 +1,21 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- turn the pre-fill startup query into a index-only scan on postgresql.
INSERT into background_updates (update_name, progress_json)
VALUES ('device_inbox_stream_index', '{}');
INSERT into background_updates (update_name, progress_json, depends_on)
VALUES ('device_inbox_stream_drop', '{}', 'device_inbox_stream_index');

View file

@ -284,7 +284,7 @@ class StateStore(SQLBaseStore):
return [r[0] for r in results]
return self.runInteraction("get_current_state_for_key", f)
@cached(num_args=2, max_entries=1000)
@cached(num_args=2, max_entries=100000, iterable=True)
def _get_state_group_from_group(self, group, types):
raise NotImplementedError()

View file

@ -40,8 +40,8 @@ def register_cache(name, cache):
)
_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR))
caches_by_name["string_cache"] = _string_cache
_string_cache = LruCache(int(100000 * CACHE_SIZE_FACTOR))
_stirng_cache_metrics = register_cache("string_cache", _string_cache)
KNOWN_KEYS = {
@ -69,7 +69,12 @@ KNOWN_KEYS = {
def intern_string(string):
"""Takes a (potentially) unicode string and interns using custom cache
"""
return _string_cache.setdefault(string, string)
new_str = _string_cache.setdefault(string, string)
if new_str is string:
_stirng_cache_metrics.inc_hits()
else:
_stirng_cache_metrics.inc_misses()
return new_str
def intern_dict(dictionary):

View file

@ -17,7 +17,7 @@ import logging
from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
)
@ -42,6 +42,25 @@ _CacheSentinel = object()
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
class CacheEntry(object):
__slots__ = [
"deferred", "sequence", "callbacks", "invalidated"
]
def __init__(self, deferred, sequence, callbacks):
self.deferred = deferred
self.sequence = sequence
self.callbacks = set(callbacks)
self.invalidated = False
def invalidate(self):
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:
callback()
self.callbacks.clear()
class Cache(object):
__slots__ = (
"cache",
@ -51,12 +70,16 @@ class Cache(object):
"sequence",
"thread",
"metrics",
"_pending_deferred_cache",
)
def __init__(self, name, max_entries=1000, keylen=1, tree=False):
def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
cache_type = TreeCache if tree else dict
self._pending_deferred_cache = cache_type()
self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type
max_size=max_entries, keylen=keylen, cache_type=cache_type,
size_callback=(lambda d: len(d.result)) if iterable else None,
)
self.name = name
@ -76,7 +99,15 @@ class Cache(object):
)
def get(self, key, default=_CacheSentinel, callback=None):
val = self.cache.get(key, _CacheSentinel, callback=callback)
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
if val.sequence == self.sequence:
val.callbacks.update(callbacks)
self.metrics.inc_hits()
return val.deferred
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
if val is not _CacheSentinel:
self.metrics.inc_hits()
return val
@ -88,15 +119,39 @@ class Cache(object):
else:
return default
def update(self, sequence, key, value, callback=None):
def set(self, key, value, callback=None):
callbacks = [callback] if callback else []
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(key, value, callback=callback)
entry = CacheEntry(
deferred=value,
sequence=self.sequence,
callbacks=callbacks,
)
entry.callbacks.update(callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
self._pending_deferred_cache[key] = entry
def shuffle(result):
if self.sequence == entry.sequence:
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
self.cache.set(key, entry.deferred, entry.callbacks)
else:
entry.invalidate()
else:
entry.invalidate()
return result
entry.deferred.addCallback(shuffle)
def prefill(self, key, value, callback=None):
self.cache.set(key, value, callback=callback)
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key):
self.check_thread()
@ -108,6 +163,10 @@ class Cache(object):
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
entry = self._pending_deferred_cache.pop(key, None)
if entry:
entry.invalidate()
self.cache.pop(key, None)
def invalidate_many(self, key):
@ -119,6 +178,11 @@ class Cache(object):
self.sequence += 1
self.cache.del_multi(key)
entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()
def invalidate_all(self):
self.check_thread()
self.sequence += 1
@ -155,7 +219,7 @@ class CacheDescriptor(object):
"""
def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
inlineCallbacks=False, cache_context=False):
inlineCallbacks=False, cache_context=False, iterable=False):
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig
@ -169,6 +233,8 @@ class CacheDescriptor(object):
self.num_args = num_args
self.tree = tree
self.iterable = iterable
all_args = inspect.getargspec(orig)
self.arg_names = all_args.args[1:num_args + 1]
@ -203,6 +269,7 @@ class CacheDescriptor(object):
max_entries=self.max_entries,
keylen=self.num_args,
tree=self.tree,
iterable=self.iterable,
)
@functools.wraps(self.orig)
@ -243,11 +310,6 @@ class CacheDescriptor(object):
return preserve_context_over_deferred(observer)
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
ret = defer.maybeDeferred(
preserve_context_over_fn,
self.function_to_call,
@ -261,7 +323,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
cache.update(sequence, cache_key, ret, callback=invalidate_callback)
cache.set(cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe())
@ -359,7 +421,6 @@ class CacheListDescriptor(object):
missing.append(arg)
if missing:
sequence = cache.sequence
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
@ -382,8 +443,8 @@ class CacheListDescriptor(object):
key = list(keyargs)
key[self.list_pos] = arg
cache.update(
sequence, tuple(key), observer,
cache.set(
tuple(key), observer,
callback=invalidate_callback
)
@ -421,17 +482,20 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
self.cache.invalidate(self.key)
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
iterable=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
tree=tree,
cache_context=cache_context,
iterable=iterable,
)
def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
iterable=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
@ -439,6 +503,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
tree=tree,
inlineCallbacks=True,
cache_context=cache_context,
iterable=iterable,
)

View file

@ -23,7 +23,9 @@ import logging
logger = logging.getLogger(__name__)
DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value"))
class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))):
def __len__(self):
return len(self.value)
class DictionaryCache(object):
@ -32,7 +34,7 @@ class DictionaryCache(object):
"""
def __init__(self, name, max_entries=1000):
self.cache = LruCache(max_size=max_entries)
self.cache = LruCache(max_size=max_entries, size_callback=len)
self.name = name
self.sequence = 0

View file

@ -15,6 +15,7 @@
from synapse.util.caches import register_cache
from collections import OrderedDict
import logging
@ -23,7 +24,7 @@ logger = logging.getLogger(__name__)
class ExpiringCache(object):
def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
reset_expiry_on_get=False):
reset_expiry_on_get=False, iterable=False):
"""
Args:
cache_name (str): Name of this cache, used for logging.
@ -36,6 +37,8 @@ class ExpiringCache(object):
evicted based on time.
reset_expiry_on_get (bool): If true, will reset the expiry time for
an item on access. Defaults to False.
iterable (bool): If true, the size is calculated by summing the
sizes of all entries, rather than the number of entries.
"""
self._cache_name = cache_name
@ -47,9 +50,13 @@ class ExpiringCache(object):
self._reset_expiry_on_get = reset_expiry_on_get
self._cache = {}
self._cache = OrderedDict()
self.metrics = register_cache(cache_name, self._cache)
self.metrics = register_cache(cache_name, self)
self.iterable = iterable
self._size_estimate = 0
def start(self):
if not self._expiry_ms:
@ -65,15 +72,14 @@ class ExpiringCache(object):
now = self._clock.time_msec()
self._cache[key] = _CacheEntry(now, value)
# Evict if there are now too many items
if self._max_len and len(self._cache.keys()) > self._max_len:
sorted_entries = sorted(
self._cache.items(),
key=lambda item: item[1].time,
)
if self.iterable:
self._size_estimate += len(value)
for k, _ in sorted_entries[self._max_len:]:
self._cache.pop(k)
# Evict if there are now too many items
while self._max_len and len(self) > self._max_len:
_key, value = self._cache.popitem(last=False)
if self.iterable:
self._size_estimate -= len(value.value)
def __getitem__(self, key):
try:
@ -99,7 +105,7 @@ class ExpiringCache(object):
# zero expiry time means don't expire. This should never get called
# since we have this check in start too.
return
begin_length = len(self._cache)
begin_length = len(self)
now = self._clock.time_msec()
@ -110,15 +116,20 @@ class ExpiringCache(object):
keys_to_delete.add(key)
for k in keys_to_delete:
self._cache.pop(k)
value = self._cache.pop(k)
if self.iterable:
self._size_estimate -= len(value.value)
logger.debug(
"[%s] _prune_cache before: %d, after len: %d",
self._cache_name, begin_length, len(self._cache)
self._cache_name, begin_length, len(self)
)
def __len__(self):
return len(self._cache)
if self.iterable:
return self._size_estimate
else:
return len(self._cache)
class _CacheEntry(object):

View file

@ -49,7 +49,7 @@ class LruCache(object):
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
"""
def __init__(self, max_size, keylen=1, cache_type=dict):
def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
cache = cache_type()
self.cache = cache # Used for introspection.
list_root = _Node(None, None, None, None)
@ -58,6 +58,12 @@ class LruCache(object):
lock = threading.Lock()
def evict():
while cache_len() > max_size:
todelete = list_root.prev_node
delete_node(todelete)
cache.pop(todelete.key, None)
def synchronized(f):
@wraps(f)
def inner(*args, **kwargs):
@ -66,6 +72,16 @@ class LruCache(object):
return inner
cached_cache_len = [0]
if size_callback is not None:
def cache_len():
return cached_cache_len[0]
else:
def cache_len():
return len(cache)
self.len = synchronized(cache_len)
def add_node(key, value, callbacks=set()):
prev_node = list_root
next_node = prev_node.next_node
@ -74,6 +90,9 @@ class LruCache(object):
next_node.prev_node = node
cache[key] = node
if size_callback:
cached_cache_len[0] += size_callback(node.value)
def move_node_to_front(node):
prev_node = node.prev_node
next_node = node.next_node
@ -92,23 +111,25 @@ class LruCache(object):
prev_node.next_node = next_node
next_node.prev_node = prev_node
if size_callback:
cached_cache_len[0] -= size_callback(node.value)
for cb in node.callbacks:
cb()
node.callbacks.clear()
@synchronized
def cache_get(key, default=None, callback=None):
def cache_get(key, default=None, callbacks=[]):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
if callback:
node.callbacks.add(callback)
node.callbacks.update(callbacks)
return node.value
else:
return default
@synchronized
def cache_set(key, value, callback=None):
def cache_set(key, value, callbacks=[]):
node = cache.get(key, None)
if node is not None:
if value != node.value:
@ -116,21 +137,18 @@ class LruCache(object):
cb()
node.callbacks.clear()
if callback:
node.callbacks.add(callback)
if size_callback:
cached_cache_len[0] -= size_callback(node.value)
cached_cache_len[0] += size_callback(value)
node.callbacks.update(callbacks)
move_node_to_front(node)
node.value = value
else:
if callback:
callbacks = set([callback])
else:
callbacks = set()
add_node(key, value, callbacks)
if len(cache) > max_size:
todelete = list_root.prev_node
delete_node(todelete)
cache.pop(todelete.key, None)
add_node(key, value, set(callbacks))
evict()
@synchronized
def cache_set_default(key, value):
@ -139,10 +157,7 @@ class LruCache(object):
return node.value
else:
add_node(key, value)
if len(cache) > max_size:
todelete = list_root.prev_node
delete_node(todelete)
cache.pop(todelete.key, None)
evict()
return value
@synchronized
@ -174,10 +189,8 @@ class LruCache(object):
for cb in node.callbacks:
cb()
cache.clear()
@synchronized
def cache_len():
return len(cache)
if size_callback:
cached_cache_len[0] = 0
@synchronized
def cache_contains(key):
@ -190,7 +203,7 @@ class LruCache(object):
self.pop = cache_pop
if cache_type is TreeCache:
self.del_multi = cache_del_multi
self.len = cache_len
self.len = synchronized(cache_len)
self.contains = cache_contains
self.clear = cache_clear

View file

@ -65,12 +65,27 @@ class TreeCache(object):
return popped
def values(self):
return [e.value for e in self.root.values()]
return list(iterate_tree_cache_entry(self.root))
def __len__(self):
return self.size
def iterate_tree_cache_entry(d):
"""Helper function to iterate over the leaves of a tree, i.e. a dict of that
can contain dicts.
"""
if isinstance(d, dict):
for value_d in d.itervalues():
for value in iterate_tree_cache_entry(value_d):
yield value
else:
if isinstance(d, _Entry):
yield d.value
else:
yield d
class _Entry(object):
__slots__ = ["value"]

View file

@ -1,71 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, reactor
from functools import wraps
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
def debug_deferreds():
"""Cause all deferreds to wait for a reactor tick before running their
callbacks. This increases the chance of getting a stack trace out of
a defer.inlineCallback since the code waiting on the deferred will get
a chance to add an errback before the deferred runs."""
# Helper method for retrieving and restoring the current logging context
# around a callback.
def with_logging_context(fn):
context = LoggingContext.current_context()
def restore_context_callback(x):
with PreserveLoggingContext(context):
return fn(x)
return restore_context_callback
# We are going to modify the __init__ method of defer.Deferred so we
# need to get a copy of the old method so we can still call it.
old__init__ = defer.Deferred.__init__
# We need to create a deferred to bounce the callbacks through the reactor
# but we don't want to add a callback when we create that deferred so we
# we create a new type of deferred that uses the old __init__ method.
# This is safe as long as the old __init__ method doesn't invoke an
# __init__ using super.
class Bouncer(defer.Deferred):
__init__ = old__init__
# We'll add this as a callback to all Deferreds. Twisted will wait until
# the bouncer deferred resolves before calling the callbacks of the
# original deferred.
def bounce_callback(x):
bouncer = Bouncer()
reactor.callLater(0, with_logging_context(bouncer.callback), x)
return bouncer
# We'll add this as an errback to all Deferreds. Twisted will wait until
# the bouncer deferred resolves before calling the errbacks of the
# original deferred.
def bounce_errback(x):
bouncer = Bouncer()
reactor.callLater(0, with_logging_context(bouncer.errback), x)
return bouncer
@wraps(old__init__)
def new__init__(self, *args, **kargs):
old__init__(self, *args, **kargs)
self.addCallbacks(bounce_callback, bounce_errback)
defer.Deferred.__init__ = new__init__

View file

@ -25,10 +25,13 @@ from synapse.api.filtering import Filter
from synapse.events import FrozenEvent
user_localpart = "test_user"
# MockEvent = namedtuple("MockEvent", "sender type room_id")
def MockEvent(**kwargs):
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
return FrozenEvent(kwargs)

View file

@ -21,6 +21,10 @@ from synapse.events.utils import prune_event, serialize_event
def MockEvent(**kwargs):
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
return FrozenEvent(kwargs)
@ -35,9 +39,13 @@ class PruneEventTestCase(unittest.TestCase):
def test_minimal(self):
self.run_test(
{'type': 'A'},
{
'type': 'A',
'event_id': '$test:domain',
},
{
'type': 'A',
'event_id': '$test:domain',
'content': {},
'signatures': {},
'unsigned': {},
@ -69,10 +77,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test(
{
'type': 'B',
'event_id': '$test:domain',
'unsigned': {'age_ts': 20},
},
{
'type': 'B',
'event_id': '$test:domain',
'content': {},
'signatures': {},
'unsigned': {'age_ts': 20},
@ -82,10 +92,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test(
{
'type': 'B',
'event_id': '$test:domain',
'unsigned': {'other_key': 'here'},
},
{
'type': 'B',
'event_id': '$test:domain',
'content': {},
'signatures': {},
'unsigned': {},
@ -96,10 +108,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test(
{
'type': 'C',
'event_id': '$test:domain',
'content': {'things': 'here'},
},
{
'type': 'C',
'event_id': '$test:domain',
'content': {},
'signatures': {},
'unsigned': {},
@ -109,10 +123,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test(
{
'type': 'm.room.create',
'event_id': '$test:domain',
'content': {'creator': '@2:domain', 'other_field': 'here'},
},
{
'type': 'm.room.create',
'event_id': '$test:domain',
'content': {'creator': '@2:domain'},
'signatures': {},
'unsigned': {},
@ -255,6 +271,8 @@ class SerializeEventTestCase(unittest.TestCase):
self.assertEquals(
self.serialize(
MockEvent(
type="foo",
event_id="test",
room_id="!foo:bar",
content={
"foo": "bar",
@ -263,6 +281,8 @@ class SerializeEventTestCase(unittest.TestCase):
[]
),
{
"type": "foo",
"event_id": "test",
"room_id": "!foo:bar",
"content": {
"foo": "bar",

View file

@ -241,7 +241,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
callcount2 = [0]
class A(object):
@cached(max_entries=2)
@cached(max_entries=20) # HACK: This makes it 2 due to cache factor
def func(self, key):
callcount[0] += 1
return key
@ -258,6 +258,10 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func("foo3")
self.assertEquals(callcount[0], 3)

View file

@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
# Copyright 2017 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import unittest
from synapse.util.caches.expiringcache import ExpiringCache
from tests.utils import MockClock
class ExpiringCacheTestCase(unittest.TestCase):
def test_get_set(self):
clock = MockClock()
cache = ExpiringCache("test", clock, max_len=1)
cache["key"] = "value"
self.assertEquals(cache.get("key"), "value")
self.assertEquals(cache["key"], "value")
def test_eviction(self):
clock = MockClock()
cache = ExpiringCache("test", clock, max_len=2)
cache["key"] = "value"
cache["key2"] = "value2"
self.assertEquals(cache.get("key"), "value")
self.assertEquals(cache.get("key2"), "value2")
cache["key3"] = "value3"
self.assertEquals(cache.get("key"), None)
self.assertEquals(cache.get("key2"), "value2")
self.assertEquals(cache.get("key3"), "value3")
def test_iterable_eviction(self):
clock = MockClock()
cache = ExpiringCache("test", clock, max_len=5, iterable=True)
cache["key"] = [1]
cache["key2"] = [2, 3]
cache["key3"] = [4, 5]
self.assertEquals(cache.get("key"), [1])
self.assertEquals(cache.get("key2"), [2, 3])
self.assertEquals(cache.get("key3"), [4, 5])
cache["key4"] = [6, 7]
self.assertEquals(cache.get("key"), None)
self.assertEquals(cache.get("key2"), None)
self.assertEquals(cache.get("key3"), [4, 5])
self.assertEquals(cache.get("key4"), [6, 7])
def test_time_eviction(self):
clock = MockClock()
cache = ExpiringCache("test", clock, expiry_ms=1000)
cache.start()
cache["key"] = 1
clock.advance_time(0.5)
cache["key2"] = 2
self.assertEquals(cache.get("key"), 1)
self.assertEquals(cache.get("key2"), 2)
clock.advance_time(0.9)
self.assertEquals(cache.get("key"), None)
self.assertEquals(cache.get("key2"), 2)
clock.advance_time(1)
self.assertEquals(cache.get("key"), None)
self.assertEquals(cache.get("key2"), None)

View file

@ -93,7 +93,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
cache.get("key", callbacks=[m])
self.assertFalse(m.called)
cache.get("key", "value")
@ -112,10 +112,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
cache.get("key", callbacks=[m])
self.assertFalse(m.called)
cache.get("key", callback=m)
cache.get("key", callbacks=[m])
self.assertFalse(m.called)
cache.set("key", "value2")
@ -128,7 +128,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
cache.set("key", "value")
@ -144,7 +144,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
cache.pop("key")
@ -163,10 +163,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m4 = Mock()
cache = LruCache(4, 2, cache_type=TreeCache)
cache.set(("a", "1"), "value", m1)
cache.set(("a", "2"), "value", m2)
cache.set(("b", "1"), "value", m3)
cache.set(("b", "2"), "value", m4)
cache.set(("a", "1"), "value", callbacks=[m1])
cache.set(("a", "2"), "value", callbacks=[m2])
cache.set(("b", "1"), "value", callbacks=[m3])
cache.set(("b", "2"), "value", callbacks=[m4])
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
@ -185,8 +185,8 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m2 = Mock()
cache = LruCache(5)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2])
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
@ -202,14 +202,14 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m3 = Mock(name="m3")
cache = LruCache(2)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2])
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key3", "value", m3)
cache.set("key3", "value", callbacks=[m3])
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
@ -227,8 +227,33 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key1", "value", m1)
cache.set("key1", "value", callbacks=[m1])
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 1)
class LruCacheSizedTestCase(unittest.TestCase):
def test_evict(self):
cache = LruCache(5, size_callback=len)
cache["key1"] = [0]
cache["key2"] = [1, 2]
cache["key3"] = [3]
cache["key4"] = [4]
self.assertEquals(cache["key1"], [0])
self.assertEquals(cache["key2"], [1, 2])
self.assertEquals(cache["key3"], [3])
self.assertEquals(cache["key4"], [4])
self.assertEquals(len(cache), 5)
cache["key5"] = [5, 6]
self.assertEquals(len(cache), 4)
self.assertEquals(cache.get("key1"), None)
self.assertEquals(cache.get("key2"), None)
self.assertEquals(cache["key3"], [3])
self.assertEquals(cache["key4"], [4])
self.assertEquals(cache["key5"], [5, 6])