Port the Password Auth Providers module interface to the new generic interface (#10548)

Co-authored-by: Azrenbeth <7782548+Azrenbeth@users.noreply.github.com>
Co-authored-by: Brendan Abolivier <babolivier@matrix.org>
This commit is contained in:
Azrenbeth 2021-10-13 12:21:52 +01:00 committed by GitHub
parent 732bbf6737
commit cdd308845b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 789 additions and 224 deletions

View file

@ -0,0 +1 @@
Port the Password Auth Providers module interface to the new generic interface.

View file

@ -43,6 +43,7 @@
- [Third-party rules callbacks](modules/third_party_rules_callbacks.md) - [Third-party rules callbacks](modules/third_party_rules_callbacks.md)
- [Presence router callbacks](modules/presence_router_callbacks.md) - [Presence router callbacks](modules/presence_router_callbacks.md)
- [Account validity callbacks](modules/account_validity_callbacks.md) - [Account validity callbacks](modules/account_validity_callbacks.md)
- [Password auth provider callbacks](modules/password_auth_provider_callbacks.md)
- [Porting a legacy module to the new interface](modules/porting_legacy_module.md) - [Porting a legacy module to the new interface](modules/porting_legacy_module.md)
- [Workers](workers.md) - [Workers](workers.md)
- [Using `synctl` with Workers](synctl_workers.md) - [Using `synctl` with Workers](synctl_workers.md)

View file

@ -0,0 +1,153 @@
# Password auth provider callbacks
Password auth providers offer a way for server administrators to integrate
their Synapse installation with an external authentication system. The callbacks can be
registered by using the Module API's `register_password_auth_provider_callbacks` method.
## Callbacks
### `auth_checkers`
```
auth_checkers: Dict[Tuple[str,Tuple], Callable]
```
A dict mapping from tuples of a login type identifier (such as `m.login.password`) and a
tuple of field names (such as `("password", "secret_thing")`) to authentication checking
callbacks, which should be of the following form:
```python
async def check_auth(
user: str,
login_type: str,
login_dict: "synapse.module_api.JsonDict",
) -> Optional[
Tuple[
str,
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]]
]
]
```
The login type and field names should be provided by the user in the
request to the `/login` API. [The Matrix specification](https://matrix.org/docs/spec/client_server/latest#authentication-types)
defines some types, however user defined ones are also allowed.
The callback is passed the `user` field provided by the client (which might not be in
`@username:server` form), the login type, and a dictionary of login secrets passed by
the client.
If the authentication is successful, the module must return the user's Matrix ID (e.g.
`@alice:example.com`) and optionally a callback to be called with the response to the
`/login` request. If the module doesn't wish to return a callback, it must return `None`
instead.
If the authentication is unsuccessful, the module must return `None`.
### `check_3pid_auth`
```python
async def check_3pid_auth(
medium: str,
address: str,
password: str,
) -> Optional[
Tuple[
str,
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]]
]
]
```
Called when a user attempts to register or log in with a third party identifier,
such as email. It is passed the medium (eg. `email`), an address (eg. `jdoe@example.com`)
and the user's password.
If the authentication is successful, the module must return the user's Matrix ID (e.g.
`@alice:example.com`) and optionally a callback to be called with the response to the `/login` request.
If the module doesn't wish to return a callback, it must return None instead.
If the authentication is unsuccessful, the module must return None.
### `on_logged_out`
```python
async def on_logged_out(
user_id: str,
device_id: Optional[str],
access_token: str
) -> None
```
Called during a logout request for a user. It is passed the qualified user ID, the ID of the
deactivated device (if any: access tokens are occasionally created without an associated
device ID), and the (now deactivated) access token.
## Example
The example module below implements authentication checkers for two different login types:
- `my.login.type`
- Expects a `my_field` field to be sent to `/login`
- Is checked by the method: `self.check_my_login`
- `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based))
- Expects a `password` field to be sent to `/login`
- Is checked by the method: `self.check_pass`
```python
from typing import Awaitable, Callable, Optional, Tuple
import synapse
from synapse import module_api
class MyAuthProvider:
def __init__(self, config: dict, api: module_api):
self.api = api
self.credentials = {
"bob": "building",
"@scoop:matrix.org": "digging",
}
api.register_password_auth_provider_callbacks(
auth_checkers={
("my.login_type", ("my_field",)): self.check_my_login,
("m.login.password", ("password",)): self.check_pass,
},
)
async def check_my_login(
self,
username: str,
login_type: str,
login_dict: "synapse.module_api.JsonDict",
) -> Optional[
Tuple[
str,
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]],
]
]:
if login_type != "my.login_type":
return None
if self.credentials.get(username) == login_dict.get("my_field"):
return self.api.get_qualified_user_id(username)
async def check_pass(
self,
username: str,
login_type: str,
login_dict: "synapse.module_api.JsonDict",
) -> Optional[
Tuple[
str,
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]],
]
]:
if login_type != "m.login.password":
return None
if self.credentials.get(username) == login_dict.get("password"):
return self.api.get_qualified_user_id(username)
```

View file

@ -12,6 +12,9 @@ should register this resource in its `__init__` method using the `register_web_r
method from the `ModuleApi` class (see [this section](writing_a_module.html#registering-a-web-resource) for method from the `ModuleApi` class (see [this section](writing_a_module.html#registering-a-web-resource) for
more info). more info).
There is no longer a `get_db_schema_files` callback provided for password auth provider modules. Any
changes to the database should now be made by the module using the module API class.
The module's author should also update any example in the module's configuration to only The module's author should also update any example in the module's configuration to only
use the new `modules` section in Synapse's configuration file (see [this section](index.html#using-modules) use the new `modules` section in Synapse's configuration file (see [this section](index.html#using-modules)
for more info). for more info).

View file

@ -1,3 +1,9 @@
<h2 style="color:red">
This page of the Synapse documentation is now deprecated. For up to date
documentation on setting up or writing a password auth provider module, please see
<a href="modules.md">this page</a>.
</h2>
# Password auth provider modules # Password auth provider modules
Password auth providers offer a way for server administrators to Password auth providers offer a way for server administrators to

View file

@ -2260,34 +2260,6 @@ email:
#email_validation: "[%(server_name)s] Validate your email" #email_validation: "[%(server_name)s] Validate your email"
# Password providers allow homeserver administrators to integrate
# their Synapse installation with existing authentication methods
# ex. LDAP, external tokens, etc.
#
# For more information and known implementations, please see
# https://matrix-org.github.io/synapse/latest/password_auth_providers.html
#
# Note: instances wishing to use SAML or CAS authentication should
# instead use the `saml2_config` or `cas_config` options,
# respectively.
#
password_providers:
# # Example config for an LDAP auth provider
# - module: "ldap_auth_provider.LdapAuthProvider"
# config:
# enabled: true
# uri: "ldap://ldap.example.com:389"
# start_tls: true
# base: "ou=users,dc=example,dc=com"
# attributes:
# uid: "cn"
# mail: "email"
# name: "givenName"
# #bind_dn:
# #bind_password:
# #filter: "(objectClass=posixAccount)"
## Push ## ## Push ##

View file

@ -42,6 +42,7 @@ from synapse.crypto import context_factory
from synapse.events.presence_router import load_legacy_presence_router from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.logging.context import PreserveLoggingContext from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats from synapse.metrics.jemalloc import setup_jemalloc_stats
@ -379,6 +380,7 @@ async def start(hs: "HomeServer"):
load_legacy_spam_checkers(hs) load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs) load_legacy_third_party_event_rules(hs)
load_legacy_presence_router(hs) load_legacy_presence_router(hs)
load_legacy_password_auth_providers(hs)
# If we've configured an expiry time for caches, start the background job now. # If we've configured an expiry time for caches, start the background job now.
setup_expire_lru_cache_entries(hs) setup_expire_lru_cache_entries(hs)

View file

@ -25,6 +25,29 @@ class PasswordAuthProviderConfig(Config):
section = "authproviders" section = "authproviders"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
"""Parses the old password auth providers config. The config format looks like this:
password_providers:
# Example config for an LDAP auth provider
- module: "ldap_auth_provider.LdapAuthProvider"
config:
enabled: true
uri: "ldap://ldap.example.com:389"
start_tls: true
base: "ou=users,dc=example,dc=com"
attributes:
uid: "cn"
mail: "email"
name: "givenName"
#bind_dn:
#bind_password:
#filter: "(objectClass=posixAccount)"
We expect admins to use modules for this feature (which is why it doesn't appear
in the sample config file), but we want to keep support for it around for a bit
for backwards compatibility.
"""
self.password_providers: List[Tuple[Type, Any]] = [] self.password_providers: List[Tuple[Type, Any]] = []
providers = [] providers = []
@ -49,33 +72,3 @@ class PasswordAuthProviderConfig(Config):
) )
self.password_providers.append((provider_class, provider_config)) self.password_providers.append((provider_class, provider_config))
def generate_config_section(self, **kwargs):
return """\
# Password providers allow homeserver administrators to integrate
# their Synapse installation with existing authentication methods
# ex. LDAP, external tokens, etc.
#
# For more information and known implementations, please see
# https://matrix-org.github.io/synapse/latest/password_auth_providers.html
#
# Note: instances wishing to use SAML or CAS authentication should
# instead use the `saml2_config` or `cas_config` options,
# respectively.
#
password_providers:
# # Example config for an LDAP auth provider
# - module: "ldap_auth_provider.LdapAuthProvider"
# config:
# enabled: true
# uri: "ldap://ldap.example.com:389"
# start_tls: true
# base: "ou=users,dc=example,dc=com"
# attributes:
# uid: "cn"
# mail: "email"
# name: "givenName"
# #bind_dn:
# #bind_password:
# #filter: "(objectClass=posixAccount)"
"""

View file

@ -200,46 +200,13 @@ class AuthHandler:
self.bcrypt_rounds = hs.config.registration.bcrypt_rounds self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
# we can't use hs.get_module_api() here, because to do so will create an self.password_auth_provider = hs.get_password_auth_provider()
# import loop.
#
# TODO: refactor this class to separate the lower-level stuff that
# ModuleApi can use from the higher-level stuff that uses ModuleApi, as
# better way to break the loop
account_handler = ModuleApi(hs, self)
self.password_providers = [
PasswordProvider.load(module, config, account_handler)
for module, config in hs.config.authproviders.password_providers
]
logger.info("Extra password_providers: %s", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.auth.password_enabled self._password_enabled = hs.config.auth.password_enabled
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
# start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = set()
if self._password_localdb_enabled:
login_types.add(LoginType.PASSWORD)
for provider in self.password_providers:
login_types.update(provider.get_supported_login_types().keys())
if not self._password_enabled:
login_types.discard(LoginType.PASSWORD)
# Some clients just pick the first type in the list. In this case, we want
# them to use PASSWORD (rather than token or whatever), so we want to make sure
# that comes first, where it's present.
self._supported_login_types = []
if LoginType.PASSWORD in login_types:
self._supported_login_types.append(LoginType.PASSWORD)
login_types.remove(LoginType.PASSWORD)
self._supported_login_types.extend(login_types)
# Ratelimiter for failed auth during UIA. Uses same ratelimit config # Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`. # as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter( self._failed_uia_attempts_ratelimiter = Ratelimiter(
@ -427,8 +394,7 @@ class AuthHandler:
ui_auth_types.add(LoginType.PASSWORD) ui_auth_types.add(LoginType.PASSWORD)
# also allow auth from password providers # also allow auth from password providers
for provider in self.password_providers: for t in self.password_auth_provider.get_supported_login_types().keys():
for t in provider.get_supported_login_types().keys():
if t == LoginType.PASSWORD and not self._password_enabled: if t == LoginType.PASSWORD and not self._password_enabled:
continue continue
ui_auth_types.add(t) ui_auth_types.add(t)
@ -1038,7 +1004,25 @@ class AuthHandler:
Returns: Returns:
login types login types
""" """
return self._supported_login_types # Load any login types registered by modules
# This is stored in the password_auth_provider so this doesn't trigger
# any callbacks
types = list(self.password_auth_provider.get_supported_login_types().keys())
# This list should include PASSWORD if (either _password_localdb_enabled is
# true or if one of the modules registered it) AND _password_enabled is true
# Also:
# Some clients just pick the first type in the list. In this case, we want
# them to use PASSWORD (rather than token or whatever), so we want to make sure
# that comes first, where it's present.
if LoginType.PASSWORD in types:
types.remove(LoginType.PASSWORD)
if self._password_enabled:
types.insert(0, LoginType.PASSWORD)
elif self._password_localdb_enabled and self._password_enabled:
types.insert(0, LoginType.PASSWORD)
return types
async def validate_login( async def validate_login(
self, self,
@ -1217,15 +1201,20 @@ class AuthHandler:
known_login_type = False known_login_type = False
for provider in self.password_providers: # Check if login_type matches a type registered by one of the modules
supported_login_types = provider.get_supported_login_types() # We don't need to remove LoginType.PASSWORD from the list if password login is
if login_type not in supported_login_types: # disabled, since if that were the case then by this point we know that the
# this password provider doesn't understand this login type # login_type is not LoginType.PASSWORD
continue supported_login_types = self.password_auth_provider.get_supported_login_types()
# check if the login type being used is supported by a module
if login_type in supported_login_types:
# Make a note that this login type is supported by the server
known_login_type = True known_login_type = True
# Get all the fields expected for this login types
login_fields = supported_login_types[login_type] login_fields = supported_login_types[login_type]
# go through the login submission and keep track of which required fields are
# provided/not provided
missing_fields = [] missing_fields = []
login_dict = {} login_dict = {}
for f in login_fields: for f in login_fields:
@ -1233,6 +1222,7 @@ class AuthHandler:
missing_fields.append(f) missing_fields.append(f)
else: else:
login_dict[f] = login_submission[f] login_dict[f] = login_submission[f]
# raise an error if any of the expected fields for that login type weren't provided
if missing_fields: if missing_fields:
raise SynapseError( raise SynapseError(
400, 400,
@ -1240,10 +1230,15 @@ class AuthHandler:
% (login_type, missing_fields), % (login_type, missing_fields),
) )
result = await provider.check_auth(username, login_type, login_dict) # call all of the check_auth hooks for that login_type
# it will return a result once the first success is found (or None otherwise)
result = await self.password_auth_provider.check_auth(
username, login_type, login_dict
)
if result: if result:
return result return result
# if no module managed to authenticate the user, then fallback to built in password based auth
if login_type == LoginType.PASSWORD and self._password_localdb_enabled: if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True known_login_type = True
@ -1282,11 +1277,16 @@ class AuthHandler:
completed login/registration, or `None`. If authentication was completed login/registration, or `None`. If authentication was
unsuccessful, `user_id` and `callback` are both `None`. unsuccessful, `user_id` and `callback` are both `None`.
""" """
for provider in self.password_providers: # call all of the check_3pid_auth callbacks
result = await provider.check_3pid_auth(medium, address, password) # Result will be from the first callback that returns something other than None
# If all the callbacks return None, then result is also set to None
result = await self.password_auth_provider.check_3pid_auth(
medium, address, password
)
if result: if result:
return result return result
# if result is None then return (None, None)
return None, None return None, None
async def _check_local_password(self, user_id: str, password: str) -> Optional[str]: async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
@ -1365,9 +1365,8 @@ class AuthHandler:
user_info = await self.auth.get_user_by_access_token(access_token) user_info = await self.auth.get_user_by_access_token(access_token)
await self.store.delete_access_token(access_token) await self.store.delete_access_token(access_token)
# see if any of our auth providers want to know about this # see if any modules want to know about this
for provider in self.password_providers: await self.password_auth_provider.on_logged_out(
await provider.on_logged_out(
user_id=user_info.user_id, user_id=user_info.user_id,
device_id=user_info.device_id, device_id=user_info.device_id,
access_token=access_token, access_token=access_token,
@ -1398,10 +1397,9 @@ class AuthHandler:
user_id, except_token_id=except_token_id, device_id=device_id user_id, except_token_id=except_token_id, device_id=device_id
) )
# see if any of our auth providers want to know about this # see if any modules want to know about this
for provider in self.password_providers:
for token, _, device_id in tokens_and_devices: for token, _, device_id in tokens_and_devices:
await provider.on_logged_out( await self.password_auth_provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token user_id=user_id, device_id=device_id, access_token=token
) )
@ -1811,40 +1809,228 @@ class MacaroonGenerator:
return macaroon return macaroon
class PasswordProvider: def load_legacy_password_auth_providers(hs: "HomeServer") -> None:
"""Wrapper for a password auth provider module module_api = hs.get_module_api()
for module, config in hs.config.authproviders.password_providers:
load_single_legacy_password_auth_provider(
module=module, config=config, api=module_api
)
This class abstracts out all of the backwards-compatibility hacks for
password providers, to provide a consistent interface.
"""
@classmethod def load_single_legacy_password_auth_provider(
def load( module: Type, config: JsonDict, api: ModuleApi
cls, module: Type, config: JsonDict, module_api: ModuleApi ) -> None:
) -> "PasswordProvider":
try: try:
pp = module(config=config, account_handler=module_api) provider = module(config=config, account_handler=api)
except Exception as e: except Exception as e:
logger.error("Error while initializing %r: %s", module, e) logger.error("Error while initializing %r: %s", module, e)
raise raise
return cls(pp, module_api)
def __init__(self, pp: "PasswordProvider", module_api: ModuleApi): # The known hooks. If a module implements a method who's name appears in this set
self._pp = pp # we'll want to register it
self._module_api = module_api password_auth_provider_methods = {
"check_3pid_auth",
"on_logged_out",
}
self._supported_login_types = {} # All methods that the module provides should be async, but this wasn't enforced
# in the old module system, so we wrap them if needed
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
return None
# grandfather in check_password support # We need to wrap check_password because its old form would return a boolean
if hasattr(self._pp, "check_password"): # but we now want it to behave just like check_auth() and return the matrix id of
self._supported_login_types[LoginType.PASSWORD] = ("password",) # the user if authentication succeeded or None otherwise
if f.__name__ == "check_password":
g = getattr(self._pp, "get_supported_login_types", None) async def wrapped_check_password(
if g: username: str, login_type: str, login_dict: JsonDict
self._supported_login_types.update(g()) ) -> Optional[Tuple[str, Optional[Callable]]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
assert f is not None
def __str__(self) -> str: matrix_user_id = api.get_qualified_user_id(username)
return str(self._pp) password = login_dict["password"]
is_valid = await f(matrix_user_id, password)
if is_valid:
return matrix_user_id, None
return None
return wrapped_check_password
# We need to wrap check_auth as in the old form it could return
# just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
if f.__name__ == "check_auth":
async def wrapped_check_auth(
username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
assert f is not None
result = await f(username, login_type, login_dict)
if isinstance(result, str):
return result, None
return result
return wrapped_check_auth
# We need to wrap check_3pid_auth as in the old form it could return
# just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
if f.__name__ == "check_3pid_auth":
async def wrapped_check_3pid_auth(
medium: str, address: str, password: str
) -> Optional[Tuple[str, Optional[Callable]]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
assert f is not None
result = await f(medium, address, password)
if isinstance(result, str):
return result, None
return result
return wrapped_check_3pid_auth
def run(*args: Tuple, **kwargs: Dict) -> Awaitable:
# mypy doesn't do well across function boundaries so we need to tell it
# f is definitely not None.
assert f is not None
return maybe_awaitable(f(*args, **kwargs))
return run
# populate hooks with the implemented methods, wrapped with async_wrapper
hooks = {
hook: async_wrapper(getattr(provider, hook, None))
for hook in password_auth_provider_methods
}
supported_login_types = {}
# call get_supported_login_types and add that to the dict
g = getattr(provider, "get_supported_login_types", None)
if g is not None:
# Note the old module style also called get_supported_login_types at loading time
# and it is synchronous
supported_login_types.update(g())
auth_checkers = {}
# Legacy modules have a check_auth method which expects to be called with one of
# the keys returned by get_supported_login_types. New style modules register a
# dictionary of login_type->check_auth_method mappings
check_auth = async_wrapper(getattr(provider, "check_auth", None))
if check_auth is not None:
for login_type, fields in supported_login_types.items():
# need tuple(fields) since fields can be any Iterable type (so may not be hashable)
auth_checkers[(login_type, tuple(fields))] = check_auth
# if it has a "check_password" method then it should handle all auth checks
# with login type of LoginType.PASSWORD
check_password = async_wrapper(getattr(provider, "check_password", None))
if check_password is not None:
# need to use a tuple here for ("password",) not a list since lists aren't hashable
auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password
api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers)
CHECK_3PID_AUTH_CALLBACK = Callable[
[str, str, str],
Awaitable[
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
]
ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
CHECK_AUTH_CALLBACK = Callable[
[str, str, JsonDict],
Awaitable[
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
]
class PasswordAuthProvider:
"""
A class that the AuthHandler calls when authenticating users
It allows modules to provide alternative methods for authentication
"""
def __init__(self) -> None:
# lists of callbacks
self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
# Mapping from login type to login parameters
self._supported_login_types: Dict[str, Iterable[str]] = {}
# Mapping from login type to auth checker callbacks
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
def register_password_auth_provider_callbacks(
self,
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None,
) -> None:
# Register check_3pid_auth callback
if check_3pid_auth is not None:
self.check_3pid_auth_callbacks.append(check_3pid_auth)
# register on_logged_out callback
if on_logged_out is not None:
self.on_logged_out_callbacks.append(on_logged_out)
if auth_checkers is not None:
# register a new supported login_type
# Iterate through all of the types being registered
for (login_type, fields), callback in auth_checkers.items():
# Note: fields may be empty here. This would allow a modules auth checker to
# be called with just 'login_type' and no password or other secrets
# Need to check that all the field names are strings or may get nasty errors later
for f in fields:
if not isinstance(f, str):
raise RuntimeError(
"A module tried to register support for login type: %s with parameters %s"
" but all parameter names must be strings"
% (login_type, fields)
)
# 2 modules supporting the same login type must expect the same fields
# e.g. 1 can't expect "pass" if the other expects "password"
# so throw an exception if that happens
if login_type not in self._supported_login_types.get(login_type, []):
self._supported_login_types[login_type] = fields
else:
fields_currently_supported = self._supported_login_types.get(
login_type
)
if fields_currently_supported != fields:
raise RuntimeError(
"A module tried to register support for login type: %s with parameters %s"
" but another module had already registered support for that type with parameters %s"
% (login_type, fields, fields_currently_supported)
)
# Add the new method to the list of auth_checker_callbacks for this login type
self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
"""Get the login types supported by this password provider """Get the login types supported by this password provider
@ -1852,20 +2038,15 @@ class PasswordProvider:
Returns a map from a login type identifier (such as m.login.password) to an Returns a map from a login type identifier (such as m.login.password) to an
iterable giving the fields which must be provided by the user in the submission iterable giving the fields which must be provided by the user in the submission
to the /login API. to the /login API.
This wrapper adds m.login.password to the list if the underlying password
provider supports the check_password() api.
""" """
return self._supported_login_types return self._supported_login_types
async def check_auth( async def check_auth(
self, username: str, login_type: str, login_dict: JsonDict self, username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]: ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
"""Check if the user has presented valid login credentials """Check if the user has presented valid login credentials
This wrapper also calls check_password() if the underlying password provider
supports the check_password() api and the login type is m.login.password.
Args: Args:
username: user id presented by the client. Either an MXID or an unqualified username: user id presented by the client. Either an MXID or an unqualified
username. username.
@ -1879,63 +2060,130 @@ class PasswordProvider:
user, and `callback` is an optional callback which will be called with the user, and `callback` is an optional callback which will be called with the
result from the /login call (including access_token, device_id, etc.) result from the /login call (including access_token, device_id, etc.)
""" """
# first grandfather in a call to check_password
if login_type == LoginType.PASSWORD: # Go through all callbacks for the login type until one returns with a value
check_password = getattr(self._pp, "check_password", None) # other than None (i.e. until a callback returns a success)
if check_password: for callback in self.auth_checker_callbacks[login_type]:
qualified_user_id = self._module_api.get_qualified_user_id(username) try:
is_valid = await check_password( result = await callback(username, login_type, login_dict)
qualified_user_id, login_dict["password"] except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
if result is not None:
# Check that the callback returned a Tuple[str, Optional[Callable]]
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
# result is always the right type, but as it is 3rd party code it might not be
if not isinstance(result, tuple) or len(result) != 2:
logger.warning(
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
) )
if is_valid: continue
return qualified_user_id, None
check_auth = getattr(self._pp, "check_auth", None) # pull out the two parts of the tuple so we can do type checking
if not check_auth: str_result, callback_result = result
return None
result = await check_auth(username, login_type, login_dict)
# Check if the return value is a str or a tuple # the 1st item in the tuple should be a str
if isinstance(result, str): if not isinstance(str_result, str):
# If it's a str, set callback function to None logger.warning( # type: ignore[unreachable]
return result, None "Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# the second should be Optional[Callable]
if callback_result is not None:
if not callable(callback_result):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# The result is a (str, Optional[callback]) tuple so return the successful result
return result return result
# If this point has been reached then none of the callbacks successfully authenticated
# the user so return None
return None
async def check_3pid_auth( async def check_3pid_auth(
self, medium: str, address: str, password: str self, medium: str, address: str, password: str
) -> Optional[Tuple[str, Optional[Callable]]]: ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
g = getattr(self._pp, "check_3pid_auth", None)
if not g:
return None
# This function is able to return a deferred that either # This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon # resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of # success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run # (user_id, callback_func), where callback_func should be run
# after we've finished everything else # after we've finished everything else
result = await g(medium, address, password)
# Check if the return value is a str or a tuple for callback in self.check_3pid_auth_callbacks:
if isinstance(result, str): try:
# If it's a str, set callback function to None result = await callback(medium, address, password)
return result, None except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
if result is not None:
# Check that the callback returned a Tuple[str, Optional[Callable]]
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
# result is always the right type, but as it is 3rd party code it might not be
if not isinstance(result, tuple) or len(result) != 2:
logger.warning(
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# pull out the two parts of the tuple so we can do type checking
str_result, callback_result = result
# the 1st item in the tuple should be a str
if not isinstance(str_result, str):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# the second should be Optional[Callable]
if callback_result is not None:
if not callable(callback_result):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# The result is a (str, Optional[callback]) tuple so return the successful result
return result return result
# If this point has been reached then none of the callbacks successfully authenticated
# the user so return None
return None
async def on_logged_out( async def on_logged_out(
self, user_id: str, device_id: Optional[str], access_token: str self, user_id: str, device_id: Optional[str], access_token: str
) -> None: ) -> None:
g = getattr(self._pp, "on_logged_out", None)
if not g:
return
# This might return an awaitable, if it does block the log out # call all of the on_logged_out callbacks
# until it completes. for callback in self.on_logged_out_callbacks:
await maybe_awaitable( try:
g( callback(user_id, device_id, access_token)
user_id=user_id, except Exception as e:
device_id=device_id, logger.warning("Failed to run module API callback %s: %s", callback, e)
access_token=access_token, continue
)
)

View file

@ -45,6 +45,7 @@ from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client.login import LoginResponse
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
@ -83,6 +84,8 @@ __all__ = [
"DirectServeJsonResource", "DirectServeJsonResource",
"ModuleApi", "ModuleApi",
"PRESENCE_ALL_USERS", "PRESENCE_ALL_USERS",
"LoginResponse",
"JsonDict",
] ]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -139,6 +142,7 @@ class ModuleApi:
self._spam_checker = hs.get_spam_checker() self._spam_checker = hs.get_spam_checker()
self._account_validity_handler = hs.get_account_validity_handler() self._account_validity_handler = hs.get_account_validity_handler()
self._third_party_event_rules = hs.get_third_party_event_rules() self._third_party_event_rules = hs.get_third_party_event_rules()
self._password_auth_provider = hs.get_password_auth_provider()
self._presence_router = hs.get_presence_router() self._presence_router = hs.get_presence_router()
################################################################################# #################################################################################
@ -164,6 +168,11 @@ class ModuleApi:
"""Registers callbacks for presence router capabilities.""" """Registers callbacks for presence router capabilities."""
return self._presence_router.register_presence_router_callbacks return self._presence_router.register_presence_router_callbacks
@property
def register_password_auth_provider_callbacks(self):
"""Registers callbacks for password auth provider capabilities."""
return self._password_auth_provider.register_password_auth_provider_callbacks
def register_web_resource(self, path: str, resource: IResource): def register_web_resource(self, path: str, resource: IResource):
"""Registers a web resource to be served at the given path. """Registers a web resource to be served at the given path.

View file

@ -65,7 +65,7 @@ from synapse.handlers.account_data import AccountDataHandler
from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.account_validity import AccountValidityHandler
from synapse.handlers.admin import AdminHandler from synapse.handlers.admin import AdminHandler
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGenerator from synapse.handlers.auth import AuthHandler, MacaroonGenerator, PasswordAuthProvider
from synapse.handlers.cas import CasHandler from synapse.handlers.cas import CasHandler
from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.deactivate_account import DeactivateAccountHandler
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
@ -687,6 +687,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_third_party_event_rules(self) -> ThirdPartyEventRules: def get_third_party_event_rules(self) -> ThirdPartyEventRules:
return ThirdPartyEventRules(self) return ThirdPartyEventRules(self)
@cache_in_self
def get_password_auth_provider(self) -> PasswordAuthProvider:
return PasswordAuthProvider()
@cache_in_self @cache_in_self
def get_room_member_handler(self) -> RoomMemberHandler: def get_room_member_handler(self) -> RoomMemberHandler:
if self.config.worker.worker_app: if self.config.worker.worker_app:

View file

@ -549,6 +549,8 @@ def _apply_module_schemas(
database_engine: database_engine:
config: application config config: application config
""" """
# This is the old way for password_auth_provider modules to make changes
# to the database. This should instead be done using the module API
for (mod, _config) in config.authproviders.password_providers: for (mod, _config) in config.authproviders.password_providers:
if not hasattr(mod, "get_db_schema_files"): if not hasattr(mod, "get_db_schema_files"):
continue continue

View file

@ -20,6 +20,8 @@ from unittest.mock import Mock
from twisted.internet import defer from twisted.internet import defer
import synapse import synapse
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
from synapse.rest.client import devices, login from synapse.rest.client import devices, login
from synapse.types import JsonDict from synapse.types import JsonDict
@ -36,8 +38,8 @@ ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_servi
mock_password_provider = Mock() mock_password_provider = Mock()
class PasswordOnlyAuthProvider: class LegacyPasswordOnlyAuthProvider:
"""A password_provider which only implements `check_password`.""" """A legacy password_provider which only implements `check_password`."""
@staticmethod @staticmethod
def parse_config(self): def parse_config(self):
@ -50,8 +52,8 @@ class PasswordOnlyAuthProvider:
return mock_password_provider.check_password(*args) return mock_password_provider.check_password(*args)
class CustomAuthProvider: class LegacyCustomAuthProvider:
"""A password_provider which implements a custom login type.""" """A legacy password_provider which implements a custom login type."""
@staticmethod @staticmethod
def parse_config(self): def parse_config(self):
@ -67,7 +69,23 @@ class CustomAuthProvider:
return mock_password_provider.check_auth(*args) return mock_password_provider.check_auth(*args)
class PasswordCustomAuthProvider: class CustomAuthProvider:
"""A module which registers password_auth_provider callbacks for a custom login type."""
@staticmethod
def parse_config(self):
pass
def __init__(self, config, api: ModuleApi):
api.register_password_auth_provider_callbacks(
auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
)
def check_auth(self, *args):
return mock_password_provider.check_auth(*args)
class LegacyPasswordCustomAuthProvider:
"""A password_provider which implements password login via `check_auth`, as well """A password_provider which implements password login via `check_auth`, as well
as a custom type.""" as a custom type."""
@ -85,8 +103,32 @@ class PasswordCustomAuthProvider:
return mock_password_provider.check_auth(*args) return mock_password_provider.check_auth(*args)
def providers_config(*providers: Type[Any]) -> dict: class PasswordCustomAuthProvider:
"""Returns a config dict that will enable the given password auth providers""" """A module which registers password_auth_provider callbacks for a custom login type.
as well as a password login"""
@staticmethod
def parse_config(self):
pass
def __init__(self, config, api: ModuleApi):
api.register_password_auth_provider_callbacks(
auth_checkers={
("test.login_type", ("test_field",)): self.check_auth,
("m.login.password", ("password",)): self.check_auth,
},
)
pass
def check_auth(self, *args):
return mock_password_provider.check_auth(*args)
def check_pass(self, *args):
return mock_password_provider.check_password(*args)
def legacy_providers_config(*providers: Type[Any]) -> dict:
"""Returns a config dict that will enable the given legacy password auth providers"""
return { return {
"password_providers": [ "password_providers": [
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
@ -95,6 +137,16 @@ def providers_config(*providers: Type[Any]) -> dict:
} }
def providers_config(*providers: Type[Any]) -> dict:
"""Returns a config dict that will enable the given modules"""
return {
"modules": [
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
for provider in providers
]
}
class PasswordAuthProviderTests(unittest.HomeserverTestCase): class PasswordAuthProviderTests(unittest.HomeserverTestCase):
servlets = [ servlets = [
synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets,
@ -107,8 +159,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
super().setUp() super().setUp()
@override_config(providers_config(PasswordOnlyAuthProvider)) def make_homeserver(self, reactor, clock):
def test_password_only_auth_provider_login(self): hs = self.setup_test_homeserver()
# Load the modules into the homeserver
module_api = hs.get_module_api()
for module, config in hs.config.modules.loaded_modules:
module(config=config, api=module_api)
load_legacy_password_auth_providers(hs)
return hs
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_progiver_login_legacy(self):
self.password_only_auth_provider_login_test_body()
def password_only_auth_provider_login_test_body(self):
# login flows should only have m.login.password # login flows should only have m.login.password
flows = self._get_login_flows() flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
@ -138,8 +203,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"@ USER🙂NAME :test", " pASS😢word " "@ USER🙂NAME :test", " pASS😢word "
) )
@override_config(providers_config(PasswordOnlyAuthProvider)) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_provider_ui_auth(self): def test_password_only_auth_provider_ui_auth_legacy(self):
self.password_only_auth_provider_ui_auth_test_body()
def password_only_auth_provider_ui_auth_test_body(self):
"""UI Auth should delegate correctly to the password provider""" """UI Auth should delegate correctly to the password provider"""
# create the user, otherwise access doesn't work # create the user, otherwise access doesn't work
@ -172,8 +240,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p") mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@override_config(providers_config(PasswordOnlyAuthProvider)) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_local_user_fallback_login(self): def test_local_user_fallback_login_legacy(self):
self.local_user_fallback_login_test_body()
def local_user_fallback_login_test_body(self):
"""rejected login should fall back to local db""" """rejected login should fall back to local db"""
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
@ -186,8 +257,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@localuser:test", channel.json_body["user_id"]) self.assertEqual("@localuser:test", channel.json_body["user_id"])
@override_config(providers_config(PasswordOnlyAuthProvider)) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_local_user_fallback_ui_auth(self): def test_local_user_fallback_ui_auth_legacy(self):
self.local_user_fallback_ui_auth_test_body()
def local_user_fallback_ui_auth_test_body(self):
"""rejected login should fall back to local db""" """rejected login should fall back to local db"""
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
@ -223,11 +297,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config( @override_config(
{ {
**providers_config(PasswordOnlyAuthProvider), **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"localdb_enabled": False}, "password_config": {"localdb_enabled": False},
} }
) )
def test_no_local_user_fallback_login(self): def test_no_local_user_fallback_login_legacy(self):
self.no_local_user_fallback_login_test_body()
def no_local_user_fallback_login_test_body(self):
"""localdb_enabled can block login with the local password""" """localdb_enabled can block login with the local password"""
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
@ -242,11 +319,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config( @override_config(
{ {
**providers_config(PasswordOnlyAuthProvider), **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"localdb_enabled": False}, "password_config": {"localdb_enabled": False},
} }
) )
def test_no_local_user_fallback_ui_auth(self): def test_no_local_user_fallback_ui_auth_legacy(self):
self.no_local_user_fallback_ui_auth_test_body()
def no_local_user_fallback_ui_auth_test_body(self):
"""localdb_enabled can block ui auth with the local password""" """localdb_enabled can block ui auth with the local password"""
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
@ -280,11 +360,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config( @override_config(
{ {
**providers_config(PasswordOnlyAuthProvider), **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"enabled": False}, "password_config": {"enabled": False},
} }
) )
def test_password_auth_disabled(self): def test_password_auth_disabled_legacy(self):
self.password_auth_disabled_test_body()
def password_auth_disabled_test_body(self):
"""password auth doesn't work if it's disabled across the board""" """password auth doesn't work if it's disabled across the board"""
# login flows should be empty # login flows should be empty
flows = self._get_login_flows() flows = self._get_login_flows()
@ -295,8 +378,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_password.assert_not_called() mock_password_provider.check_password.assert_not_called()
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_login_legacy(self):
self.custom_auth_provider_login_test_body()
@override_config(providers_config(CustomAuthProvider)) @override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_login(self): def test_custom_auth_provider_login(self):
self.custom_auth_provider_login_test_body()
def custom_auth_provider_login_test_body(self):
# login flows should have the custom flow and m.login.password, since we # login flows should have the custom flow and m.login.password, since we
# haven't disabled local password lookup. # haven't disabled local password lookup.
# (password must come first, because reasons) # (password must come first, because reasons)
@ -312,7 +402,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") mock_password_provider.check_auth.return_value = defer.succeed(
("@user:bz", None)
)
channel = self._send_login("test.login_type", "u", test_field="y") channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"]) self.assertEqual("@user:bz", channel.json_body["user_id"])
@ -325,7 +417,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# in these cases, but at least we can guard against the API changing # in these cases, but at least we can guard against the API changing
# unexpectedly # unexpectedly
mock_password_provider.check_auth.return_value = defer.succeed( mock_password_provider.check_auth.return_value = defer.succeed(
"@ MALFORMED! :bz" ("@ MALFORMED! :bz", None)
) )
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
@ -334,8 +426,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
" USER🙂NAME ", "test.login_type", {"test_field": " abc "} " USER🙂NAME ", "test.login_type", {"test_field": " abc "}
) )
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_ui_auth_legacy(self):
self.custom_auth_provider_ui_auth_test_body()
@override_config(providers_config(CustomAuthProvider)) @override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_ui_auth(self): def test_custom_auth_provider_ui_auth(self):
self.custom_auth_provider_ui_auth_test_body()
def custom_auth_provider_ui_auth_test_body(self):
# register the user and log in twice, to get two devices # register the user and log in twice, to get two devices
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
tok1 = self.login("localuser", "localpass") tok1 = self.login("localuser", "localpass")
@ -367,7 +466,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
# right params, but authing as the wrong user # right params, but authing as the wrong user
mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") mock_password_provider.check_auth.return_value = defer.succeed(
("@user:bz", None)
)
body["auth"]["test_field"] = "foo" body["auth"]["test_field"] = "foo"
channel = self._delete_device(tok1, "dev2", body) channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 403) self.assertEqual(channel.code, 403)
@ -379,7 +480,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# and finally, succeed # and finally, succeed
mock_password_provider.check_auth.return_value = defer.succeed( mock_password_provider.check_auth.return_value = defer.succeed(
"@localuser:test" ("@localuser:test", None)
) )
channel = self._delete_device(tok1, "dev2", body) channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
@ -387,8 +488,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"localuser", "test.login_type", {"test_field": "foo"} "localuser", "test.login_type", {"test_field": "foo"}
) )
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_callback_legacy(self):
self.custom_auth_provider_callback_test_body()
@override_config(providers_config(CustomAuthProvider)) @override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_callback(self): def test_custom_auth_provider_callback(self):
self.custom_auth_provider_callback_test_body()
def custom_auth_provider_callback_test_body(self):
callback = Mock(return_value=defer.succeed(None)) callback = Mock(return_value=defer.succeed(None))
mock_password_provider.check_auth.return_value = defer.succeed( mock_password_provider.check_auth.return_value = defer.succeed(
@ -410,10 +518,22 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
for p in ["user_id", "access_token", "device_id", "home_server"]: for p in ["user_id", "access_token", "device_id", "home_server"]:
self.assertIn(p, call_args[0]) self.assertIn(p, call_args[0])
@override_config(
{
**legacy_providers_config(LegacyCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_custom_auth_password_disabled_legacy(self):
self.custom_auth_password_disabled_test_body()
@override_config( @override_config(
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}} {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
) )
def test_custom_auth_password_disabled(self): def test_custom_auth_password_disabled(self):
self.custom_auth_password_disabled_test_body()
def custom_auth_password_disabled_test_body(self):
"""Test login with a custom auth provider where password login is disabled""" """Test login with a custom auth provider where password login is disabled"""
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
@ -425,6 +545,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_auth.assert_not_called()
@override_config(
{
**legacy_providers_config(LegacyCustomAuthProvider),
"password_config": {"enabled": False, "localdb_enabled": False},
}
)
def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
self.custom_auth_password_disabled_localdb_enabled_test_body()
@override_config( @override_config(
{ {
**providers_config(CustomAuthProvider), **providers_config(CustomAuthProvider),
@ -432,6 +561,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
} }
) )
def test_custom_auth_password_disabled_localdb_enabled(self): def test_custom_auth_password_disabled_localdb_enabled(self):
self.custom_auth_password_disabled_localdb_enabled_test_body()
def custom_auth_password_disabled_localdb_enabled_test_body(self):
"""Check the localdb_enabled == enabled == False """Check the localdb_enabled == enabled == False
Regression test for https://github.com/matrix-org/synapse/issues/8914: check Regression test for https://github.com/matrix-org/synapse/issues/8914: check
@ -448,6 +580,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_auth.assert_not_called()
@override_config(
{
**legacy_providers_config(LegacyPasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_login_legacy(self):
self.password_custom_auth_password_disabled_login_test_body()
@override_config( @override_config(
{ {
**providers_config(PasswordCustomAuthProvider), **providers_config(PasswordCustomAuthProvider),
@ -455,6 +596,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
} }
) )
def test_password_custom_auth_password_disabled_login(self): def test_password_custom_auth_password_disabled_login(self):
self.password_custom_auth_password_disabled_login_test_body()
def password_custom_auth_password_disabled_login_test_body(self):
"""log in with a custom auth provider which implements password, but password """log in with a custom auth provider which implements password, but password
login is disabled""" login is disabled"""
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
@ -466,6 +610,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
channel = self._send_password_login("localuser", "localpass") channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_password.assert_not_called()
@override_config(
{
**legacy_providers_config(LegacyPasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
self.password_custom_auth_password_disabled_ui_auth_test_body()
@override_config( @override_config(
{ {
@ -474,12 +628,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
} }
) )
def test_password_custom_auth_password_disabled_ui_auth(self): def test_password_custom_auth_password_disabled_ui_auth(self):
self.password_custom_auth_password_disabled_ui_auth_test_body()
def password_custom_auth_password_disabled_ui_auth_test_body(self):
"""UI Auth with a custom auth provider which implements password, but password """UI Auth with a custom auth provider which implements password, but password
login is disabled""" login is disabled"""
# register the user and log in twice via the test login type to get two devices, # register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")
mock_password_provider.check_auth.return_value = defer.succeed( mock_password_provider.check_auth.return_value = defer.succeed(
"@localuser:test" ("@localuser:test", None)
) )
channel = self._send_login("test.login_type", "localuser", test_field="") channel = self._send_login("test.login_type", "localuser", test_field="")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
@ -516,6 +673,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"Password login has been disabled.", channel.json_body["error"] "Password login has been disabled.", channel.json_body["error"]
) )
mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_password.assert_not_called()
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
# successful auth # successful auth
@ -526,6 +684,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_auth.assert_called_once_with( mock_password_provider.check_auth.assert_called_once_with(
"localuser", "test.login_type", {"test_field": "x"} "localuser", "test.login_type", {"test_field": "x"}
) )
mock_password_provider.check_password.assert_not_called()
@override_config(
{
**legacy_providers_config(LegacyCustomAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
def test_custom_auth_no_local_user_fallback_legacy(self):
self.custom_auth_no_local_user_fallback_test_body()
@override_config( @override_config(
{ {
@ -534,6 +702,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
} }
) )
def test_custom_auth_no_local_user_fallback(self): def test_custom_auth_no_local_user_fallback(self):
self.custom_auth_no_local_user_fallback_test_body()
def custom_auth_no_local_user_fallback_test_body(self):
"""Test login with a custom auth provider where the local db is disabled""" """Test login with a custom auth provider where the local db is disabled"""
self.register_user("localuser", "localpass") self.register_user("localuser", "localpass")