Implement access token expiry (#5660)

Record how long an access token is valid for, and raise a soft-logout once it
expires.
This commit is contained in:
Richard van der Hoff 2019-07-12 17:26:02 +01:00 committed by GitHub
parent 24aa0e0a5b
commit 5f158ec039
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 255 additions and 33 deletions

1
changelog.d/5660.feature Normal file
View file

@ -0,0 +1 @@
Implement `session_lifetime` configuration option, after which access tokens will expire.

View file

@ -786,6 +786,17 @@ uploads_path: "DATADIR/uploads"
# renew_at: 1w
# renew_email_subject: "Renew your %(app)s account"
# Time that a user's session remains valid for, after they log in.
#
# Note that this is not currently compatible with guest logins.
#
# Note also that this is calculated at login time: changes are not applied
# retrospectively to users who have already logged in.
#
# By default, this is infinite.
#
#session_lifetime: 24h
# The user must provide all of the below types of 3PID when registering.
#
#registrations_require_3pid:

View file

@ -319,6 +319,17 @@ class Auth(object):
# first look in the database
r = yield self._look_up_user_by_access_token(token)
if r:
valid_until_ms = r["valid_until_ms"]
if (
valid_until_ms is not None
and valid_until_ms < self.clock.time_msec()
):
# there was a valid access token, but it has expired.
# soft-logout the user.
raise InvalidClientTokenError(
msg="Access token has expired", soft_logout=True
)
defer.returnValue(r)
# otherwise it needs to be a valid macaroon
@ -505,6 +516,7 @@ class Auth(object):
"token_id": ret.get("token_id", None),
"is_guest": False,
"device_id": ret.get("device_id"),
"valid_until_ms": ret.get("valid_until_ms"),
}
defer.returnValue(user_info)

View file

@ -245,8 +245,14 @@ class MissingClientTokenError(InvalidClientCredentialsError):
class InvalidClientTokenError(InvalidClientCredentialsError):
"""Raised when we didn't understand the access token in a request"""
def __init__(self, msg="Unrecognised access token"):
def __init__(self, msg="Unrecognised access token", soft_logout=False):
super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN")
self._soft_logout = soft_logout
def error_dict(self):
d = super().error_dict()
d["soft_logout"] = self._soft_logout
return d
class ResourceLimitError(SynapseError):

View file

@ -84,6 +84,11 @@ class RegistrationConfig(Config):
"disable_msisdn_registration", False
)
session_lifetime = config.get("session_lifetime")
if session_lifetime is not None:
session_lifetime = self.parse_duration(session_lifetime)
self.session_lifetime = session_lifetime
def generate_config_section(self, generate_secrets=False, **kwargs):
if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % (
@ -141,6 +146,17 @@ class RegistrationConfig(Config):
# renew_at: 1w
# renew_email_subject: "Renew your %%(app)s account"
# Time that a user's session remains valid for, after they log in.
#
# Note that this is not currently compatible with guest logins.
#
# Note also that this is calculated at login time: changes are not applied
# retrospectively to users who have already logged in.
#
# By default, this is infinite.
#
#session_lifetime: 24h
# The user must provide all of the below types of 3PID when registering.
#
#registrations_require_3pid:

View file

@ -15,6 +15,7 @@
# limitations under the License.
import logging
import time
import unicodedata
import attr
@ -558,7 +559,7 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
@defer.inlineCallbacks
def get_access_token_for_user_id(self, user_id, device_id=None):
def get_access_token_for_user_id(self, user_id, device_id, valid_until_ms):
"""
Creates a new access token for the user with the given user ID.
@ -572,16 +573,26 @@ class AuthHandler(BaseHandler):
device_id (str|None): the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
valid_until_ms (int|None): when the token is valid until. None for
no expiry.
Returns:
The access token for the user's session.
Raises:
StoreError if there was a problem storing the token.
"""
logger.info("Logging in user %s on device %s", user_id, device_id)
fmt_expiry = ""
if valid_until_ms is not None:
fmt_expiry = time.strftime(
" until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0)
)
logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
yield self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id)
yield self.store.add_access_token_to_user(user_id, access_token, device_id)
yield self.store.add_access_token_to_user(
user_id, access_token, device_id, valid_until_ms
)
# the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we

View file

@ -84,6 +84,8 @@ class RegistrationHandler(BaseHandler):
self.device_handler = hs.get_device_handler()
self.pusher_pool = hs.get_pusherpool()
self.session_lifetime = hs.config.session_lifetime
@defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None, assigned_user_id=None):
if types.contains_invalid_mxid_characters(localpart):
@ -599,6 +601,8 @@ class RegistrationHandler(BaseHandler):
def register_device(self, user_id, device_id, initial_display_name, is_guest=False):
"""Register a device for a user and generate an access token.
The access token will be limited by the homeserver's session_lifetime config.
Args:
user_id (str): full canonical @user:id
device_id (str|None): The device ID to check, or None to generate
@ -619,20 +623,29 @@ class RegistrationHandler(BaseHandler):
is_guest=is_guest,
)
defer.returnValue((r["device_id"], r["access_token"]))
else:
device_id = yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if is_guest:
access_token = self.macaroon_gen.generate_access_token(
user_id, ["guest = true"]
)
else:
access_token = yield self._auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id
)
defer.returnValue((device_id, access_token))
valid_until_ms = None
if self.session_lifetime is not None:
if is_guest:
raise Exception(
"session_lifetime is not currently implemented for guest access"
)
valid_until_ms = self.clock.time_msec() + self.session_lifetime
device_id = yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if is_guest:
assert valid_until_ms is None
access_token = self.macaroon_gen.generate_access_token(
user_id, ["guest = true"]
)
else:
access_token = yield self._auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id, valid_until_ms=valid_until_ms
)
defer.returnValue((device_id, access_token))
@defer.inlineCallbacks
def post_registration_actions(

View file

@ -90,7 +90,8 @@ class RegistrationWorkerStore(SQLBaseStore):
token (str): The access token of a user.
Returns:
defer.Deferred: None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`.
including the keys `name`, `is_guest`, `device_id`, `token_id`,
`valid_until_ms`.
"""
return self.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
@ -284,7 +285,7 @@ class RegistrationWorkerStore(SQLBaseStore):
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
" access_tokens.device_id"
" access_tokens.device_id, access_tokens.valid_until_ms"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
@ -679,14 +680,16 @@ class RegistrationStore(
defer.returnValue(batch_size)
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id=None):
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
"""Adds an access token for the given user.
Args:
user_id (str): The user ID.
token (str): The new access token to add.
device_id (str): ID of the device to associate with the access
token
token
valid_until_ms (int|None): when the token is valid until. None for
no expiry.
Raises:
StoreError if there was a problem adding this.
"""
@ -694,7 +697,13 @@ class RegistrationStore(
yield self._simple_insert(
"access_tokens",
{"id": next_id, "user_id": user_id, "token": token, "device_id": device_id},
{
"id": next_id,
"user_id": user_id,
"token": token,
"device_id": device_id,
"valid_until_ms": valid_until_ms,
},
desc="add_access_token_to_user",
)

View file

@ -0,0 +1,18 @@
/* Copyright 2019 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- when this access token can be used until, in ms since the epoch. NULL means the token
-- never expires.
ALTER TABLE access_tokens ADD COLUMN valid_until_ms BIGINT;

View file

@ -262,9 +262,11 @@ class AuthTestCase(unittest.TestCase):
self.store.add_access_token_to_user = Mock()
token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id(
USER_ID, "DEVICE"
USER_ID, "DEVICE", valid_until_ms=None
)
self.store.add_access_token_to_user.assert_called_with(
USER_ID, token, "DEVICE", None
)
self.store.add_access_token_to_user.assert_called_with(USER_ID, token, "DEVICE")
def get_user(tok):
if token != tok:

View file

@ -117,7 +117,9 @@ class AuthTestCase(unittest.TestCase):
def test_mau_limits_disabled(self):
self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
yield self.auth_handler.get_access_token_for_user_id("user_a")
yield self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
@ -131,7 +133,9 @@ class AuthTestCase(unittest.TestCase):
)
with self.assertRaises(ResourceLimitError):
yield self.auth_handler.get_access_token_for_user_id("user_a")
yield self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users)
@ -150,7 +154,9 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield self.auth_handler.get_access_token_for_user_id("user_a")
yield self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
@ -166,7 +172,9 @@ class AuthTestCase(unittest.TestCase):
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
yield self.auth_handler.get_access_token_for_user_id("user_a")
yield self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
@ -185,7 +193,9 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.small_number_of_users)
)
# Ensure does not raise exception
yield self.auth_handler.get_access_token_for_user_id("user_a")
yield self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users)

View file

@ -272,7 +272,10 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
else:
yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
yield self.store.add_access_token_to_user(
user_id=user_id, token=token, device_id=None, valid_until_ms=None
)
if displayname is not None:
# logger.info("setting user display name: %s -> %s", user_id, displayname)

View file

@ -2,10 +2,14 @@ import json
import synapse.rest.admin
from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import devices
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
from tests import unittest
from tests.unittest import override_config
LOGIN_URL = b"/_matrix/client/r0/login"
TEST_URL = b"/_matrix/client/r0/account/whoami"
class LoginRestServletTestCase(unittest.HomeserverTestCase):
@ -13,6 +17,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
devices.register_servlets,
lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
]
def make_homeserver(self, reactor, clock):
@ -144,3 +150,105 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
@override_config({"session_lifetime": "24h"})
def test_soft_logout(self):
self.register_user("kermit", "monkey")
# we shouldn't be able to make requests without an access token
request, channel = self.make_request(b"GET", TEST_URL)
self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN")
# log in as normal
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
self.assertEquals(channel.code, 200, channel.result)
access_token = channel.json_body["access_token"]
device_id = channel.json_body["device_id"]
# we should now be able to make requests with the access token
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
self.render(request)
self.assertEquals(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
self.render(request)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
#
# test behaviour after deleting the expired device
#
# we now log in as a different device
access_token_2 = self.login("kermit", "monkey")
# more requests with the expired token should still return a soft-logout
self.reactor.advance(3600)
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
self.render(request)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
# ... but if we delete that device, it will be a proper logout
self._delete_device(access_token_2, "kermit", "monkey", device_id)
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
self.render(request)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], False)
def _delete_device(self, access_token, user_id, password, device_id):
"""Perform the UI-Auth to delete a device"""
request, channel = self.make_request(
b"DELETE", "devices/" + device_id, access_token=access_token
)
self.render(request)
self.assertEquals(channel.code, 401, channel.result)
# check it's a UI-Auth fail
self.assertEqual(
set(channel.json_body.keys()),
{"flows", "params", "session"},
channel.result,
)
auth = {
"type": "m.login.password",
# https://github.com/matrix-org/synapse/issues/5665
# "identifier": {"type": "m.id.user", "user": user_id},
"user": user_id,
"password": password,
"session": channel.json_body["session"],
}
request, channel = self.make_request(
b"DELETE",
"devices/" + device_id,
access_token=access_token,
content={"auth": auth},
)
self.render(request)
self.assertEquals(channel.code, 200, channel.result)

View file

@ -57,7 +57,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
def test_add_tokens(self):
yield self.store.register_user(self.user_id, self.pwhash)
yield self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
result = yield self.store.get_user_by_access_token(self.tokens[1])
@ -72,9 +72,11 @@ class RegistrationStoreTestCase(unittest.TestCase):
def test_user_delete_access_tokens(self):
# add some tokens
yield self.store.register_user(self.user_id, self.pwhash)
yield self.store.add_access_token_to_user(self.user_id, self.tokens[0])
yield self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id
self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
)
yield self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
# now delete some