Merge pull request #221 from matrix-org/auth

Simplify LoginHander and AuthHandler
This commit is contained in:
Daniel Wagner-Hall 2015-08-14 17:02:22 +01:00
commit 30883d8409
9 changed files with 102 additions and 145 deletions

View file

@ -22,7 +22,6 @@ from .room import (
from .message import MessageHandler from .message import MessageHandler
from .events import EventStreamHandler, EventHandler from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler from .federation import FederationHandler
from .login import LoginHandler
from .profile import ProfileHandler from .profile import ProfileHandler
from .presence import PresenceHandler from .presence import PresenceHandler
from .directory import DirectoryHandler from .directory import DirectoryHandler
@ -54,7 +53,6 @@ class Handlers(object):
self.profile_handler = ProfileHandler(hs) self.profile_handler = ProfileHandler(hs)
self.presence_handler = PresenceHandler(hs) self.presence_handler = PresenceHandler(hs)
self.room_list_handler = RoomListHandler(hs) self.room_list_handler = RoomListHandler(hs)
self.login_handler = LoginHandler(hs)
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.typing_notification_handler = TypingNotificationHandler(hs) self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)

View file

@ -47,17 +47,24 @@ class AuthHandler(BaseHandler):
self.sessions = {} self.sessions = {}
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip=None): def check_auth(self, flows, clientdict, clientip):
""" """
Takes a dictionary sent by the client in the login / registration Takes a dictionary sent by the client in the login / registration
protocol and handles the login flow. protocol and handles the login flow.
As a side effect, this function fills in the 'creds' key on the user's
session with a map, which maps each auth-type (str) to the relevant
identity authenticated by that auth-type (mostly str, but for captcha, bool).
Args: Args:
flows: list of list of stages flows (list): A list of login flows. Each flow is an ordered list of
authdict: The dictionary from the client root level, not the strings representing auth-types. At least one full
'auth' key: this method prompts for auth if none is sent. flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
clientip (str): The IP address of the client.
Returns: Returns:
A tuple of authed, dict, dict where authed is true if the client A tuple of (authed, dict, dict) where authed is true if the client
has successfully completed an auth flow. If it is true, the first has successfully completed an auth flow. If it is true, the first
dict contains the authenticated credentials of each stage. dict contains the authenticated credentials of each stage.
@ -75,7 +82,7 @@ class AuthHandler(BaseHandler):
del clientdict['auth'] del clientdict['auth']
if 'session' in authdict: if 'session' in authdict:
sid = authdict['session'] sid = authdict['session']
sess = self._get_session_info(sid) session = self._get_session_info(sid)
if len(clientdict) > 0: if len(clientdict) > 0:
# This was designed to allow the client to omit the parameters # This was designed to allow the client to omit the parameters
@ -87,20 +94,19 @@ class AuthHandler(BaseHandler):
# on a home server. # on a home server.
# Revisit: Assumimg the REST APIs do sensible validation, the data # Revisit: Assumimg the REST APIs do sensible validation, the data
# isn't arbintrary. # isn't arbintrary.
sess['clientdict'] = clientdict session['clientdict'] = clientdict
self._save_session(sess) self._save_session(session)
pass elif 'clientdict' in session:
elif 'clientdict' in sess: clientdict = session['clientdict']
clientdict = sess['clientdict']
if not authdict: if not authdict:
defer.returnValue( defer.returnValue(
(False, self._auth_dict_for_flows(flows, sess), clientdict) (False, self._auth_dict_for_flows(flows, session), clientdict)
) )
if 'creds' not in sess: if 'creds' not in session:
sess['creds'] = {} session['creds'] = {}
creds = sess['creds'] creds = session['creds']
# check auth type currently being presented # check auth type currently being presented
if 'type' in authdict: if 'type' in authdict:
@ -109,15 +115,15 @@ class AuthHandler(BaseHandler):
result = yield self.checkers[authdict['type']](authdict, clientip) result = yield self.checkers[authdict['type']](authdict, clientip)
if result: if result:
creds[authdict['type']] = result creds[authdict['type']] = result
self._save_session(sess) self._save_session(session)
for f in flows: for f in flows:
if len(set(f) - set(creds.keys())) == 0: if len(set(f) - set(creds.keys())) == 0:
logger.info("Auth completed with creds: %r", creds) logger.info("Auth completed with creds: %r", creds)
self._remove_session(sess) self._remove_session(session)
defer.returnValue((True, creds, clientdict)) defer.returnValue((True, creds, clientdict))
ret = self._auth_dict_for_flows(flows, sess) ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys() ret['completed'] = creds.keys()
defer.returnValue((False, ret, clientdict)) defer.returnValue((False, ret, clientdict))
@ -151,22 +157,13 @@ class AuthHandler(BaseHandler):
if "user" not in authdict or "password" not in authdict: if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM) raise LoginError(400, "", Codes.MISSING_PARAM)
user = authdict["user"] user_id = authdict["user"]
password = authdict["password"] password = authdict["password"]
if not user.startswith('@'): if not user_id.startswith('@'):
user = UserID.create(user, self.hs.hostname).to_string() user_id = UserID.create(user_id, self.hs.hostname).to_string()
user_info = yield self.store.get_user_by_id(user_id=user) self._check_password(user_id, password)
if not user_info: defer.returnValue(user_id)
logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
stored_hash = user_info["password_hash"]
if bcrypt.checkpw(password, stored_hash):
defer.returnValue(user)
else:
logger.warn("Failed password login for user %s", user)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip): def _check_recaptcha(self, authdict, clientip):
@ -270,6 +267,58 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id] return self.sessions[session_id]
@defer.inlineCallbacks
def login_with_password(self, user_id, password):
"""
Authenticates the user with their username and password.
Used only by the v1 login API.
Args:
user_id (str): User ID
password (str): Password
Returns:
The access token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
self._check_password(user_id, password)
reg_handler = self.hs.get_handlers().registration_handler
access_token = reg_handler.generate_token(user_id)
logger.info("Adding token %s for user %s", access_token, user_id)
yield self.store.add_access_token_to_user(user_id, access_token)
defer.returnValue(access_token)
def _check_password(self, user_id, password):
"""Checks that user_id has passed password, raises LoginError if not."""
user_info = yield self.store.get_user_by_id(user_id=user_id)
if not user_info:
logger.warn("Attempted to login as %s but they do not exist", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
stored_hash = user_info["password_hash"]
if not bcrypt.checkpw(password, stored_hash):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def set_password(self, user_id, newpassword):
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
yield self.store.user_set_password_hash(user_id, password_hash)
yield self.store.user_delete_access_tokens(user_id)
yield self.hs.get_pusherpool().remove_pushers_by_user(user_id)
yield self.store.flush_user(user_id)
@defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
yield self.store.user_add_threepid(
user_id, medium, address, validated_at,
self.hs.get_clock().time_msec()
)
def _save_session(self, session): def _save_session(self, session):
# TODO: Persistent storage # TODO: Persistent storage
logger.debug("Saving session %s", session) logger.debug("Saving session %s", session)

View file

@ -1,83 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.errors import LoginError, Codes
import bcrypt
import logging
logger = logging.getLogger(__name__)
class LoginHandler(BaseHandler):
def __init__(self, hs):
super(LoginHandler, self).__init__(hs)
self.hs = hs
@defer.inlineCallbacks
def login(self, user, password):
"""Login as the specified user with the specified password.
Args:
user (str): The user ID.
password (str): The password.
Returns:
The newly allocated access token.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
# TODO do this better, it can't go in __init__ else it cyclic loops
if not hasattr(self, "reg_handler"):
self.reg_handler = self.hs.get_handlers().registration_handler
# pull out the hash for this user if they exist
user_info = yield self.store.get_user_by_id(user_id=user)
if not user_info:
logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
stored_hash = user_info["password_hash"]
if bcrypt.checkpw(password, stored_hash):
# generate an access token and store it.
token = self.reg_handler._generate_token(user)
logger.info("Adding token %s for user %s", token, user)
yield self.store.add_access_token_to_user(user, token)
defer.returnValue(token)
else:
logger.warn("Failed password login for user %s", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, token_id=None):
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
yield self.store.user_set_password_hash(user_id, password_hash)
yield self.store.user_delete_access_tokens_apart_from(user_id, token_id)
yield self.hs.get_pusherpool().remove_pushers_by_user_access_token(
user_id, token_id
)
yield self.store.flush_user(user_id)
@defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
yield self.store.user_add_threepid(
user_id, medium, address, validated_at,
self.hs.get_clock().time_msec()
)

View file

@ -91,7 +91,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
token = self._generate_token(user_id) token = self.generate_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
@ -111,7 +111,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_is_valid(user_id)
token = self._generate_token(user_id) token = self.generate_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
@ -161,7 +161,7 @@ class RegistrationHandler(BaseHandler):
400, "Invalid user localpart for this application service.", 400, "Invalid user localpart for this application service.",
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
token = self._generate_token(user_id) token = self.generate_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
@ -208,7 +208,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_is_valid(user_id)
token = self._generate_token(user_id) token = self.generate_token(user_id)
try: try:
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
@ -273,7 +273,7 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
def _generate_token(self, user_id): def generate_token(self, user_id):
# urlsafe variant uses _ and - so use . as the separator and replace # urlsafe variant uses _ and - so use . as the separator and replace
# all =s with .s so http clients don't quote =s when it is used as # all =s with .s so http clients don't quote =s when it is used as
# query params. # query params.

View file

@ -94,17 +94,14 @@ class PusherPool:
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_user_access_token(self, user_id, not_access_token_id): def remove_pushers_by_user(self, user_id):
all = yield self.store.get_all_pushers() all = yield self.store.get_all_pushers()
logger.info( logger.info(
"Removing all pushers for user %s except access token %s", "Removing all pushers for user %s",
user_id, not_access_token_id user_id,
) )
for p in all: for p in all:
if ( if p['user_name'] == user_id:
p['user_name'] == user_id and
p['access_token'] != not_access_token_id
):
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name'] p['app_id'], p['pushkey'], p['user_name']

View file

@ -78,9 +78,8 @@ class LoginRestServlet(ClientV1RestServlet):
login_submission["user"] = UserID.create( login_submission["user"] = UserID.create(
login_submission["user"], self.hs.hostname).to_string() login_submission["user"], self.hs.hostname).to_string()
handler = self.handlers.login_handler token = yield self.handlers.auth_handler.login_with_password(
token = yield handler.login( user_id=login_submission["user"],
user=login_submission["user"],
password=login_submission["password"]) password=login_submission["password"])
result = { result = {

View file

@ -36,7 +36,6 @@ class PasswordRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler self.auth_handler = hs.get_handlers().auth_handler
self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -47,7 +46,7 @@ class PasswordRestServlet(RestServlet):
authed, result, params = yield self.auth_handler.check_auth([ authed, result, params = yield self.auth_handler.check_auth([
[LoginType.PASSWORD], [LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY] [LoginType.EMAIL_IDENTITY]
], body) ], body, self.hs.get_ip_from_request(request))
if not authed: if not authed:
defer.returnValue((401, result)) defer.returnValue((401, result))
@ -79,7 +78,7 @@ class PasswordRestServlet(RestServlet):
raise SynapseError(400, "", Codes.MISSING_PARAM) raise SynapseError(400, "", Codes.MISSING_PARAM)
new_password = params['new_password'] new_password = params['new_password']
yield self.login_handler.set_password( yield self.auth_handler.set_password(
user_id, new_password, None user_id, new_password, None
) )
@ -95,7 +94,6 @@ class ThreepidRestServlet(RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(ThreepidRestServlet, self).__init__() super(ThreepidRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.login_handler = hs.get_handlers().login_handler
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -135,7 +133,7 @@ class ThreepidRestServlet(RestServlet):
logger.warn("Couldn't add 3pid: invalid response from ID sevrer") logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
raise SynapseError(500, "Invalid response from ID Server") raise SynapseError(500, "Invalid response from ID Server")
yield self.login_handler.add_threepid( yield self.auth_handler.add_threepid(
auth_user.to_string(), auth_user.to_string(),
threepid['medium'], threepid['medium'],
threepid['address'], threepid['address'],

View file

@ -50,7 +50,6 @@ class RegisterRestServlet(RestServlet):
self.auth_handler = hs.get_handlers().auth_handler self.auth_handler = hs.get_handlers().auth_handler
self.registration_handler = hs.get_handlers().registration_handler self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -143,7 +142,7 @@ class RegisterRestServlet(RestServlet):
if reqd not in threepid: if reqd not in threepid:
logger.info("Can't add incomplete 3pid") logger.info("Can't add incomplete 3pid")
else: else:
yield self.login_handler.add_threepid( yield self.auth_handler.add_threepid(
user_id, user_id,
threepid['medium'], threepid['medium'],
threepid['address'], threepid['address'],

View file

@ -112,16 +112,16 @@ class RegistrationStore(SQLBaseStore):
}) })
@defer.inlineCallbacks @defer.inlineCallbacks
def user_delete_access_tokens_apart_from(self, user_id, token_id): def user_delete_access_tokens(self, user_id):
yield self.runInteraction( yield self.runInteraction(
"user_delete_access_tokens_apart_from", "user_delete_access_tokens",
self._user_delete_access_tokens_apart_from, user_id, token_id self._user_delete_access_tokens, user_id
) )
def _user_delete_access_tokens_apart_from(self, txn, user_id, token_id): def _user_delete_access_tokens(self, txn, user_id):
txn.execute( txn.execute(
"DELETE FROM access_tokens WHERE user_id = ? AND id != ?", "DELETE FROM access_tokens WHERE user_id = ?",
(user_id, token_id) (user_id, )
) )
@defer.inlineCallbacks @defer.inlineCallbacks