Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2017-03-13 10:06:28 +00:00
commit a0d6987991
14 changed files with 104 additions and 401 deletions

View file

@ -1,6 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -45,7 +44,6 @@ class JoinRules(object):
class LoginType(object): class LoginType(object):
PASSWORD = u"m.login.password" PASSWORD = u"m.login.password"
EMAIL_IDENTITY = u"m.login.email.identity" EMAIL_IDENTITY = u"m.login.email.identity"
MSISDN = u"m.login.msisdn"
RECAPTCHA = u"m.login.recaptcha" RECAPTCHA = u"m.login.recaptcha"
DUMMY = u"m.login.dummy" DUMMY = u"m.login.dummy"

View file

@ -206,8 +206,7 @@ class FederationClient(FederationBase):
Args: Args:
destinations (list): Which home servers to query destinations (list): Which home servers to query
pdu_origin (str): The home server that originally sent the pdu. event_id (str): event to fetch
event_id (str)
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False` of the current block of PDUs. Defaults to `False`

View file

@ -1,6 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd # Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -48,7 +47,6 @@ class AuthHandler(BaseHandler):
LoginType.PASSWORD: self._check_password_auth, LoginType.PASSWORD: self._check_password_auth,
LoginType.RECAPTCHA: self._check_recaptcha, LoginType.RECAPTCHA: self._check_recaptcha,
LoginType.EMAIL_IDENTITY: self._check_email_identity, LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.MSISDN: self._check_msisdn,
LoginType.DUMMY: self._check_dummy_auth, LoginType.DUMMY: self._check_dummy_auth,
} }
self.bcrypt_rounds = hs.config.bcrypt_rounds self.bcrypt_rounds = hs.config.bcrypt_rounds
@ -309,47 +307,31 @@ class AuthHandler(BaseHandler):
defer.returnValue(True) defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks
def _check_email_identity(self, authdict, _): def _check_email_identity(self, authdict, _):
return self._check_threepid('email', authdict)
def _check_msisdn(self, authdict, _):
return self._check_threepid('msisdn', authdict)
@defer.inlineCallbacks
def _check_dummy_auth(self, authdict, _):
yield run_on_reactor()
defer.returnValue(True)
@defer.inlineCallbacks
def _check_threepid(self, medium, authdict):
yield run_on_reactor() yield run_on_reactor()
if 'threepid_creds' not in authdict: if 'threepid_creds' not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
threepid_creds = authdict['threepid_creds'] threepid_creds = authdict['threepid_creds']
identity_handler = self.hs.get_handlers().identity_handler identity_handler = self.hs.get_handlers().identity_handler
logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,)) logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,))
threepid = yield identity_handler.threepid_from_creds(threepid_creds) threepid = yield identity_handler.threepid_from_creds(threepid_creds)
if not threepid: if not threepid:
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
if threepid['medium'] != medium:
raise LoginError(
401,
"Expecting threepid of type '%s', got '%s'" % (
medium, threepid['medium'],
),
errcode=Codes.UNAUTHORIZED
)
threepid['threepid_creds'] = authdict['threepid_creds'] threepid['threepid_creds'] = authdict['threepid_creds']
defer.returnValue(threepid) defer.returnValue(threepid)
@defer.inlineCallbacks
def _check_dummy_auth(self, authdict, _):
yield run_on_reactor()
defer.returnValue(True)
def _get_params_recaptcha(self): def _get_params_recaptcha(self):
return {"public_key": self.hs.config.recaptcha_public_key} return {"public_key": self.hs.config.recaptcha_public_key}

View file

@ -1,6 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -151,7 +150,7 @@ class IdentityHandler(BaseHandler):
params.update(kwargs) params.update(kwargs)
try: try:
data = yield self.http_client.post_json_get_json( data = yield self.http_client.post_urlencoded_get_json(
"https://%s%s" % ( "https://%s%s" % (
id_server, id_server,
"/_matrix/identity/api/v1/validate/email/requestToken" "/_matrix/identity/api/v1/validate/email/requestToken"
@ -162,37 +161,3 @@ class IdentityHandler(BaseHandler):
except CodeMessageException as e: except CodeMessageException as e:
logger.info("Proxied requestToken failed: %r", e) logger.info("Proxied requestToken failed: %r", e)
raise e raise e
@defer.inlineCallbacks
def requestMsisdnToken(
self, id_server, country, phone_number,
client_secret, send_attempt, **kwargs
):
yield run_on_reactor()
if not self._should_trust_id_server(id_server):
raise SynapseError(
400, "Untrusted ID server '%s'" % id_server,
Codes.SERVER_NOT_TRUSTED
)
params = {
'country': country,
'phone_number': phone_number,
'client_secret': client_secret,
'send_attempt': send_attempt,
}
params.update(kwargs)
try:
data = yield self.http_client.post_json_get_json(
"https://%s%s" % (
id_server,
"/_matrix/identity/api/v1/validate/msisdn/requestToken"
),
params
)
defer.returnValue(data)
except CodeMessageException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e

View file

@ -609,14 +609,14 @@ class SyncHandler(object):
deleted = yield self.store.delete_messages_for_device( deleted = yield self.store.delete_messages_for_device(
user_id, device_id, since_stream_id user_id, device_id, since_stream_id
) )
logger.info("Deleted %d to-device messages up to %d", logger.debug("Deleted %d to-device messages up to %d",
deleted, since_stream_id) deleted, since_stream_id)
messages, stream_id = yield self.store.get_new_messages_for_device( messages, stream_id = yield self.store.get_new_messages_for_device(
user_id, device_id, since_stream_id, now_token.to_device_key user_id, device_id, since_stream_id, now_token.to_device_key
) )
logger.info( logger.debug(
"Returning %d to-device messages between %d and %d (current token: %d)", "Returning %d to-device messages between %d and %d (current token: %d)",
len(messages), since_stream_id, stream_id, now_token.to_device_key len(messages), since_stream_id, stream_id, now_token.to_device_key
) )

View file

@ -192,16 +192,6 @@ def parse_json_object_from_request(request):
return content return content
def assert_params_in_request(body, required):
absent = []
for k in required:
if k not in body:
absent.append(k)
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
class RestServlet(object): class RestServlet(object):
""" A Synapse REST Servlet. """ A Synapse REST Servlet.

View file

@ -1,5 +1,4 @@
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -38,7 +37,6 @@ REQUIREMENTS = {
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
"pymacaroons-pynacl": ["pymacaroons"], "pymacaroons-pynacl": ["pymacaroons"],
"msgpack-python>=0.3.0": ["msgpack"], "msgpack-python>=0.3.0": ["msgpack"],
"phonenumbers>=8.2.0": ["phonenumbers"],
} }
CONDITIONAL_REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = {
"web_client": { "web_client": {

View file

@ -17,6 +17,7 @@ from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.caches.expiringcache import ExpiringCache
class SlavedDeviceInboxStore(BaseSlavedStore): class SlavedDeviceInboxStore(BaseSlavedStore):
@ -34,6 +35,13 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
self._device_inbox_id_gen.get_current_token() self._device_inbox_id_gen.get_current_token()
) )
self._last_device_delete_cache = ExpiringCache(
cache_name="last_device_delete_cache",
clock=self._clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
)
get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__ get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__ get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__ get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__

View file

@ -19,7 +19,6 @@ from synapse.api.errors import SynapseError, LoginError, Codes
from synapse.types import UserID from synapse.types import UserID
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.util.msisdn import phone_number_to_msisdn
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -38,49 +37,6 @@ import xml.etree.ElementTree as ET
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def login_submission_legacy_convert(submission):
"""
If the input login submission is an old style object
(ie. with top-level user / medium / address) convert it
to a typed object.
"""
if "user" in submission:
submission["identifier"] = {
"type": "m.id.user",
"user": submission["user"],
}
del submission["user"]
if "medium" in submission and "address" in submission:
submission["identifier"] = {
"type": "m.id.thirdparty",
"medium": submission["medium"],
"address": submission["address"],
}
del submission["medium"]
del submission["address"]
def login_id_thirdparty_from_phone(identifier):
"""
Convert a phone login identifier type to a generic threepid identifier
Args:
identifier(dict): Login identifier dict of type 'm.id.phone'
Returns: Login identifier dict of type 'm.id.threepid'
"""
if "country" not in identifier or "number" not in identifier:
raise SynapseError(400, "Invalid phone-type identifier")
msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
return {
"type": "m.id.thirdparty",
"medium": "msisdn",
"address": msisdn,
}
class LoginRestServlet(ClientV1RestServlet): class LoginRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login$") PATTERNS = client_path_patterns("/login$")
PASS_TYPE = "m.login.password" PASS_TYPE = "m.login.password"
@ -161,52 +117,20 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_password_login(self, login_submission): def do_password_login(self, login_submission):
if "password" not in login_submission: if 'medium' in login_submission and 'address' in login_submission:
raise SynapseError(400, "Missing parameter: password") address = login_submission['address']
if login_submission['medium'] == 'email':
login_submission_legacy_convert(login_submission)
if "identifier" not in login_submission:
raise SynapseError(400, "Missing param: identifier")
identifier = login_submission["identifier"]
if "type" not in identifier:
raise SynapseError(400, "Login identifier has no type")
# convert phone type identifiers to generic threepids
if identifier["type"] == "m.id.phone":
identifier = login_id_thirdparty_from_phone(identifier)
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
if 'medium' not in identifier or 'address' not in identifier:
raise SynapseError(400, "Invalid thirdparty identifier")
address = identifier['address']
if identifier['medium'] == 'email':
# For emails, transform the address to lowercase. # For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB. # We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py) # (See add_threepid in synapse/handlers/auth.py)
address = address.lower() address = address.lower()
user_id = yield self.hs.get_datastore().get_user_id_by_threepid( user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
identifier['medium'], address login_submission['medium'], address
) )
if not user_id: if not user_id:
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
else:
identifier = { user_id = login_submission['user']
"type": "m.id.user",
"user": user_id,
}
# by this point, the identifier should be an m.id.user: if it's anything
# else, we haven't understood it.
if identifier["type"] != "m.id.user":
raise SynapseError(400, "Unknown login identifier type")
if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key")
user_id = identifier["user"]
if not user_id.startswith('@'): if not user_id.startswith('@'):
user_id = UserID.create( user_id = UserID.create(

View file

@ -1,6 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -18,11 +17,8 @@ from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import LoginError, SynapseError, Codes from synapse.api.errors import LoginError, SynapseError, Codes
from synapse.http.servlet import ( from synapse.http.servlet import RestServlet, parse_json_object_from_request
RestServlet, parse_json_object_from_request, assert_params_in_request
)
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn
from ._base import client_v2_patterns from ._base import client_v2_patterns
@ -32,11 +28,11 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EmailPasswordRequestTokenRestServlet(RestServlet): class PasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/email/requestToken$") PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
def __init__(self, hs): def __init__(self, hs):
super(EmailPasswordRequestTokenRestServlet, self).__init__() super(PasswordRequestTokenRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
@ -44,9 +40,14 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_request(body, [ required = ['id_server', 'client_secret', 'email', 'send_attempt']
'id_server', 'client_secret', 'email', 'send_attempt' absent = []
]) for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email'] 'email', body['email']
@ -59,37 +60,6 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
defer.returnValue((200, ret)) defer.returnValue((200, ret))
class MsisdnPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$")
def __init__(self, hs):
super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
self.hs = hs
self.datastore = self.hs.get_datastore()
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
assert_params_in_request(body, [
'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt',
])
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
existingUid = yield self.datastore.get_user_id_by_threepid(
'msisdn', msisdn
)
if existingUid is None:
raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND)
ret = yield self.identity_handler.requestMsisdnToken(**body)
defer.returnValue((200, ret))
class PasswordRestServlet(RestServlet): class PasswordRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password$") PATTERNS = client_v2_patterns("/account/password$")
@ -98,7 +68,6 @@ class PasswordRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -108,8 +77,7 @@ class PasswordRestServlet(RestServlet):
authed, result, params, _ = yield self.auth_handler.check_auth([ authed, result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD], [LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY], [LoginType.EMAIL_IDENTITY]
[LoginType.MSISDN],
], body, self.hs.get_ip_from_request(request)) ], body, self.hs.get_ip_from_request(request))
if not authed: if not authed:
@ -134,7 +102,7 @@ class PasswordRestServlet(RestServlet):
# (See add_threepid in synapse/handlers/auth.py) # (See add_threepid in synapse/handlers/auth.py)
threepid['address'] = threepid['address'].lower() threepid['address'] = threepid['address'].lower()
# if using email, we must know about the email they're authing with! # if using email, we must know about the email they're authing with!
threepid_user_id = yield self.datastore.get_user_id_by_threepid( threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
threepid['medium'], threepid['address'] threepid['medium'], threepid['address']
) )
if not threepid_user_id: if not threepid_user_id:
@ -201,14 +169,13 @@ class DeactivateAccountRestServlet(RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class EmailThreepidRequestTokenRestServlet(RestServlet): class ThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$") PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
super(EmailThreepidRequestTokenRestServlet, self).__init__() super(ThreepidRequestTokenRestServlet, self).__init__()
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.datastore = self.hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -223,7 +190,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
if absent: if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.datastore.get_user_id_by_threepid( existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email'] 'email', body['email']
) )
@ -234,44 +201,6 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
defer.returnValue((200, ret)) defer.returnValue((200, ret))
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$")
def __init__(self, hs):
self.hs = hs
super(MsisdnThreepidRequestTokenRestServlet, self).__init__()
self.identity_handler = hs.get_handlers().identity_handler
self.datastore = self.hs.get_datastore()
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = [
'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt',
]
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
existingUid = yield self.datastore.get_user_id_by_threepid(
'msisdn', msisdn
)
if existingUid is not None:
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
class ThreepidRestServlet(RestServlet): class ThreepidRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid$") PATTERNS = client_v2_patterns("/account/3pid$")
@ -281,7 +210,6 @@ class ThreepidRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -289,7 +217,7 @@ class ThreepidRestServlet(RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
threepids = yield self.datastore.user_get_threepids( threepids = yield self.hs.get_datastore().user_get_threepids(
requester.user.to_string() requester.user.to_string()
) )
@ -330,7 +258,7 @@ class ThreepidRestServlet(RestServlet):
if 'bind' in body and body['bind']: if 'bind' in body and body['bind']:
logger.debug( logger.debug(
"Binding threepid %s to %s", "Binding emails %s to %s",
threepid, user_id threepid, user_id
) )
yield self.identity_handler.bind_threepid( yield self.identity_handler.bind_threepid(
@ -374,11 +302,9 @@ class ThreepidDeleteRestServlet(RestServlet):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
EmailPasswordRequestTokenRestServlet(hs).register(http_server) PasswordRequestTokenRestServlet(hs).register(http_server)
MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server)
EmailThreepidRequestTokenRestServlet(hs).register(http_server) ThreepidRequestTokenRestServlet(hs).register(http_server)
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server)

View file

@ -1,6 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 - 2016 OpenMarket Ltd # Copyright 2015 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -20,10 +19,7 @@ import synapse
from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import ( from synapse.http.servlet import RestServlet, parse_json_object_from_request
RestServlet, parse_json_object_from_request, assert_params_in_request
)
from synapse.util.msisdn import phone_number_to_msisdn
from ._base import client_v2_patterns from ._base import client_v2_patterns
@ -47,7 +43,7 @@ else:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EmailRegisterRequestTokenRestServlet(RestServlet): class RegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/email/requestToken$") PATTERNS = client_v2_patterns("/register/email/requestToken$")
def __init__(self, hs): def __init__(self, hs):
@ -55,7 +51,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
Args: Args:
hs (synapse.server.HomeServer): server hs (synapse.server.HomeServer): server
""" """
super(EmailRegisterRequestTokenRestServlet, self).__init__() super(RegisterRequestTokenRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
@ -63,9 +59,14 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_request(body, [ required = ['id_server', 'client_secret', 'email', 'send_attempt']
'id_server', 'client_secret', 'email', 'send_attempt' absent = []
]) for k in required:
if k not in body:
absent.append(k)
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email'] 'email', body['email']
@ -78,43 +79,6 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
defer.returnValue((200, ret)) defer.returnValue((200, ret))
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/msisdn/requestToken$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(MsisdnRegisterRequestTokenRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
assert_params_in_request(body, [
'id_server', 'client_secret',
'country', 'phone_number',
'send_attempt',
])
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'msisdn', msisdn
)
if existingUid is not None:
raise SynapseError(
400, "Phone number is already in use", Codes.THREEPID_IN_USE
)
ret = yield self.identity_handler.requestMsisdnToken(**body)
defer.returnValue((200, ret))
class RegisterRestServlet(RestServlet): class RegisterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register$") PATTERNS = client_v2_patterns("/register$")
@ -239,16 +203,12 @@ class RegisterRestServlet(RestServlet):
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
flows = [ flows = [
[LoginType.RECAPTCHA], [LoginType.RECAPTCHA],
[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA], [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]
[LoginType.MSISDN, LoginType.RECAPTCHA],
[LoginType.EMAIL_IDENTITY, LoginType.MSISDN, LoginType.RECAPTCHA],
] ]
else: else:
flows = [ flows = [
[LoginType.DUMMY], [LoginType.DUMMY],
[LoginType.EMAIL_IDENTITY], [LoginType.EMAIL_IDENTITY]
[LoginType.MSISDN],
[LoginType.EMAIL_IDENTITY, LoginType.MSISDN],
] ]
authed, auth_result, params, session_id = yield self.auth_handler.check_auth( authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
@ -264,9 +224,8 @@ class RegisterRestServlet(RestServlet):
"Already registered user ID %r for this session", "Already registered user ID %r for this session",
registered_user_id registered_user_id
) )
# don't re-register the threepids # don't re-register the email address
add_email = False add_email = False
add_msisdn = False
else: else:
# NB: This may be from the auth handler and NOT from the POST # NB: This may be from the auth handler and NOT from the POST
if 'password' not in params: if 'password' not in params:
@ -291,7 +250,6 @@ class RegisterRestServlet(RestServlet):
) )
add_email = True add_email = True
add_msisdn = True
return_dict = yield self._create_registration_details( return_dict = yield self._create_registration_details(
registered_user_id, params registered_user_id, params
@ -304,13 +262,6 @@ class RegisterRestServlet(RestServlet):
params.get("bind_email") params.get("bind_email")
) )
if add_msisdn and auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
yield self._register_msisdn_threepid(
registered_user_id, threepid, return_dict["access_token"],
params.get("bind_msisdn")
)
defer.returnValue((200, return_dict)) defer.returnValue((200, return_dict))
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
@ -372,9 +323,8 @@ class RegisterRestServlet(RestServlet):
""" """
reqd = ('medium', 'address', 'validated_at') reqd = ('medium', 'address', 'validated_at')
if any(x not in threepid for x in reqd): if any(x not in threepid for x in reqd):
# This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid") logger.info("Can't add incomplete 3pid")
return defer.returnValue()
yield self.auth_handler.add_threepid( yield self.auth_handler.add_threepid(
user_id, user_id,
@ -421,43 +371,6 @@ class RegisterRestServlet(RestServlet):
else: else:
logger.info("bind_email not specified: not binding email") logger.info("bind_email not specified: not binding email")
@defer.inlineCallbacks
def _register_msisdn_threepid(self, user_id, threepid, token, bind_msisdn):
"""Add a phone number as a 3pid identifier
Also optionally binds msisdn to the given user_id on the identity server
Args:
user_id (str): id of user
threepid (object): m.login.msisdn auth response
token (str): access_token for the user
bind_email (bool): true if the client requested the email to be
bound at the identity server
Returns:
defer.Deferred:
"""
reqd = ('medium', 'address', 'validated_at')
if any(x not in threepid for x in reqd):
# This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
defer.returnValue()
yield self.auth_handler.add_threepid(
user_id,
threepid['medium'],
threepid['address'],
threepid['validated_at'],
)
if bind_msisdn:
logger.info("bind_msisdn specified: binding")
logger.debug("Binding msisdn %s to %s", threepid, user_id)
yield self.identity_handler.bind_threepid(
threepid['threepid_creds'], user_id
)
else:
logger.info("bind_msisdn not specified: not binding msisdn")
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_registration_details(self, user_id, params): def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user """Complete registration of newly-registered user
@ -536,6 +449,5 @@ class RegisterRestServlet(RestServlet):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
EmailRegisterRequestTokenRestServlet(hs).register(http_server) RegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server)

View file

@ -20,6 +20,8 @@ from twisted.internet import defer
from .background_updates import BackgroundUpdateStore from .background_updates import BackgroundUpdateStore
from synapse.util.caches.expiringcache import ExpiringCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -42,6 +44,15 @@ class DeviceInboxStore(BackgroundUpdateStore):
self._background_drop_index_device_inbox, self._background_drop_index_device_inbox,
) )
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache(
cache_name="last_device_delete_cache",
clock=self._clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_messages_to_device_inbox(self, local_messages_by_user_then_device, def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
remote_messages_by_destination): remote_messages_by_destination):
@ -251,6 +262,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
"get_new_messages_for_device", get_new_messages_for_device_txn, "get_new_messages_for_device", get_new_messages_for_device_txn,
) )
@defer.inlineCallbacks
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
""" """
Args: Args:
@ -260,6 +272,18 @@ class DeviceInboxStore(BackgroundUpdateStore):
Returns: Returns:
A deferred that resolves to the number of messages deleted. A deferred that resolves to the number of messages deleted.
""" """
# If we have cached the last stream id we've deleted up to, we can
# check if there is likely to be anything that needs deleting
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), None
)
if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_deleted_stream_id
)
if not has_changed:
defer.returnValue(0)
def delete_messages_for_device_txn(txn): def delete_messages_for_device_txn(txn):
sql = ( sql = (
"DELETE FROM device_inbox" "DELETE FROM device_inbox"
@ -269,10 +293,20 @@ class DeviceInboxStore(BackgroundUpdateStore):
txn.execute(sql, (user_id, device_id, up_to_stream_id)) txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount return txn.rowcount
return self.runInteraction( count = yield self.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn "delete_messages_for_device", delete_messages_for_device_txn
) )
# Update the cache, ensuring that we only ever increase the value
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0
)
self._last_device_delete_cache[(user_id, device_id)] = max(
last_deleted_stream_id, up_to_stream_id
)
defer.returnValue(count)
def get_all_new_device_messages(self, last_pos, current_pos, limit): def get_all_new_device_messages(self, last_pos, current_pos, limit):
""" """
Args: Args:

View file

@ -100,6 +100,13 @@ class ExpiringCache(object):
except KeyError: except KeyError:
return default return default
def setdefault(self, key, value):
try:
return self[key]
except KeyError:
self[key] = value
return value
def _prune_cache(self): def _prune_cache(self):
if not self._expiry_ms: if not self._expiry_ms:
# zero expiry time means don't expire. This should never get called # zero expiry time means don't expire. This should never get called

View file

@ -1,40 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import phonenumbers
from synapse.api.errors import SynapseError
def phone_number_to_msisdn(country, number):
"""
Takes an ISO-3166-1 2 letter country code and phone number and
returns an msisdn representing the canonical version of that
phone number.
Args:
country (str): ISO-3166-1 2 letter country code
number (str): Phone number in a national or international format
Returns:
(str) The canonical form of the phone number, as an msisdn
Raises:
SynapseError if the number could not be parsed.
"""
try:
phoneNumber = phonenumbers.parse(number, country)
except phonenumbers.NumberParseException:
raise SynapseError(400, "Unable to parse phone number")
return phonenumbers.format_number(
phoneNumber, phonenumbers.PhoneNumberFormat.E164
)[1:]