Port the Password Auth Providers module interface to the new generic interface (#10548)

Co-authored-by: Azrenbeth <7782548+Azrenbeth@users.noreply.github.com>
Co-authored-by: Brendan Abolivier <babolivier@matrix.org>
This commit is contained in:
Azrenbeth 2021-10-13 12:21:52 +01:00 committed by GitHub
parent 732bbf6737
commit cdd308845b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 789 additions and 224 deletions

View file

@ -0,0 +1 @@
Port the Password Auth Providers module interface to the new generic interface.

View file

@ -43,6 +43,7 @@
- [Third-party rules callbacks](modules/third_party_rules_callbacks.md)
- [Presence router callbacks](modules/presence_router_callbacks.md)
- [Account validity callbacks](modules/account_validity_callbacks.md)
- [Password auth provider callbacks](modules/password_auth_provider_callbacks.md)
- [Porting a legacy module to the new interface](modules/porting_legacy_module.md)
- [Workers](workers.md)
- [Using `synctl` with Workers](synctl_workers.md)

View file

@ -0,0 +1,153 @@
# Password auth provider callbacks
Password auth providers offer a way for server administrators to integrate
their Synapse installation with an external authentication system. The callbacks can be
registered by using the Module API's `register_password_auth_provider_callbacks` method.
## Callbacks
### `auth_checkers`
```
auth_checkers: Dict[Tuple[str,Tuple], Callable]
```
A dict mapping from tuples of a login type identifier (such as `m.login.password`) and a
tuple of field names (such as `("password", "secret_thing")`) to authentication checking
callbacks, which should be of the following form:
```python
async def check_auth(
user: str,
login_type: str,
login_dict: "synapse.module_api.JsonDict",
) -> Optional[
Tuple[
str,
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]]
]
]
```
The login type and field names should be provided by the user in the
request to the `/login` API. [The Matrix specification](https://matrix.org/docs/spec/client_server/latest#authentication-types)
defines some types, however user defined ones are also allowed.
The callback is passed the `user` field provided by the client (which might not be in
`@username:server` form), the login type, and a dictionary of login secrets passed by
the client.
If the authentication is successful, the module must return the user's Matrix ID (e.g.
`@alice:example.com`) and optionally a callback to be called with the response to the
`/login` request. If the module doesn't wish to return a callback, it must return `None`
instead.
If the authentication is unsuccessful, the module must return `None`.
### `check_3pid_auth`
```python
async def check_3pid_auth(
medium: str,
address: str,
password: str,
) -> Optional[
Tuple[
str,
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]]
]
]
```
Called when a user attempts to register or log in with a third party identifier,
such as email. It is passed the medium (eg. `email`), an address (eg. `jdoe@example.com`)
and the user's password.
If the authentication is successful, the module must return the user's Matrix ID (e.g.
`@alice:example.com`) and optionally a callback to be called with the response to the `/login` request.
If the module doesn't wish to return a callback, it must return None instead.
If the authentication is unsuccessful, the module must return None.
### `on_logged_out`
```python
async def on_logged_out(
user_id: str,
device_id: Optional[str],
access_token: str
) -> None
```
Called during a logout request for a user. It is passed the qualified user ID, the ID of the
deactivated device (if any: access tokens are occasionally created without an associated
device ID), and the (now deactivated) access token.
## Example
The example module below implements authentication checkers for two different login types:
- `my.login.type`
- Expects a `my_field` field to be sent to `/login`
- Is checked by the method: `self.check_my_login`
- `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based))
- Expects a `password` field to be sent to `/login`
- Is checked by the method: `self.check_pass`
```python
from typing import Awaitable, Callable, Optional, Tuple
import synapse
from synapse import module_api
class MyAuthProvider:
def __init__(self, config: dict, api: module_api):
self.api = api
self.credentials = {
"bob": "building",
"@scoop:matrix.org": "digging",
}
api.register_password_auth_provider_callbacks(
auth_checkers={
("my.login_type", ("my_field",)): self.check_my_login,
("m.login.password", ("password",)): self.check_pass,
},
)
async def check_my_login(
self,
username: str,
login_type: str,
login_dict: "synapse.module_api.JsonDict",
) -> Optional[
Tuple[
str,
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]],
]
]:
if login_type != "my.login_type":
return None
if self.credentials.get(username) == login_dict.get("my_field"):
return self.api.get_qualified_user_id(username)
async def check_pass(
self,
username: str,
login_type: str,
login_dict: "synapse.module_api.JsonDict",
) -> Optional[
Tuple[
str,
Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]],
]
]:
if login_type != "m.login.password":
return None
if self.credentials.get(username) == login_dict.get("password"):
return self.api.get_qualified_user_id(username)
```

View file

@ -12,6 +12,9 @@ should register this resource in its `__init__` method using the `register_web_r
method from the `ModuleApi` class (see [this section](writing_a_module.html#registering-a-web-resource) for
more info).
There is no longer a `get_db_schema_files` callback provided for password auth provider modules. Any
changes to the database should now be made by the module using the module API class.
The module's author should also update any example in the module's configuration to only
use the new `modules` section in Synapse's configuration file (see [this section](index.html#using-modules)
for more info).

View file

@ -1,3 +1,9 @@
<h2 style="color:red">
This page of the Synapse documentation is now deprecated. For up to date
documentation on setting up or writing a password auth provider module, please see
<a href="modules.md">this page</a>.
</h2>
# Password auth provider modules
Password auth providers offer a way for server administrators to

View file

@ -2260,34 +2260,6 @@ email:
#email_validation: "[%(server_name)s] Validate your email"
# Password providers allow homeserver administrators to integrate
# their Synapse installation with existing authentication methods
# ex. LDAP, external tokens, etc.
#
# For more information and known implementations, please see
# https://matrix-org.github.io/synapse/latest/password_auth_providers.html
#
# Note: instances wishing to use SAML or CAS authentication should
# instead use the `saml2_config` or `cas_config` options,
# respectively.
#
password_providers:
# # Example config for an LDAP auth provider
# - module: "ldap_auth_provider.LdapAuthProvider"
# config:
# enabled: true
# uri: "ldap://ldap.example.com:389"
# start_tls: true
# base: "ou=users,dc=example,dc=com"
# attributes:
# uid: "cn"
# mail: "email"
# name: "givenName"
# #bind_dn:
# #bind_password:
# #filter: "(objectClass=posixAccount)"
## Push ##

View file

@ -42,6 +42,7 @@ from synapse.crypto import context_factory
from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
@ -379,6 +380,7 @@ async def start(hs: "HomeServer"):
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
load_legacy_presence_router(hs)
load_legacy_password_auth_providers(hs)
# If we've configured an expiry time for caches, start the background job now.
setup_expire_lru_cache_entries(hs)

View file

@ -25,6 +25,29 @@ class PasswordAuthProviderConfig(Config):
section = "authproviders"
def read_config(self, config, **kwargs):
"""Parses the old password auth providers config. The config format looks like this:
password_providers:
# Example config for an LDAP auth provider
- module: "ldap_auth_provider.LdapAuthProvider"
config:
enabled: true
uri: "ldap://ldap.example.com:389"
start_tls: true
base: "ou=users,dc=example,dc=com"
attributes:
uid: "cn"
mail: "email"
name: "givenName"
#bind_dn:
#bind_password:
#filter: "(objectClass=posixAccount)"
We expect admins to use modules for this feature (which is why it doesn't appear
in the sample config file), but we want to keep support for it around for a bit
for backwards compatibility.
"""
self.password_providers: List[Tuple[Type, Any]] = []
providers = []
@ -49,33 +72,3 @@ class PasswordAuthProviderConfig(Config):
)
self.password_providers.append((provider_class, provider_config))
def generate_config_section(self, **kwargs):
return """\
# Password providers allow homeserver administrators to integrate
# their Synapse installation with existing authentication methods
# ex. LDAP, external tokens, etc.
#
# For more information and known implementations, please see
# https://matrix-org.github.io/synapse/latest/password_auth_providers.html
#
# Note: instances wishing to use SAML or CAS authentication should
# instead use the `saml2_config` or `cas_config` options,
# respectively.
#
password_providers:
# # Example config for an LDAP auth provider
# - module: "ldap_auth_provider.LdapAuthProvider"
# config:
# enabled: true
# uri: "ldap://ldap.example.com:389"
# start_tls: true
# base: "ou=users,dc=example,dc=com"
# attributes:
# uid: "cn"
# mail: "email"
# name: "givenName"
# #bind_dn:
# #bind_password:
# #filter: "(objectClass=posixAccount)"
"""

View file

@ -200,46 +200,13 @@ class AuthHandler:
self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
# we can't use hs.get_module_api() here, because to do so will create an
# import loop.
#
# TODO: refactor this class to separate the lower-level stuff that
# ModuleApi can use from the higher-level stuff that uses ModuleApi, as
# better way to break the loop
account_handler = ModuleApi(hs, self)
self.password_providers = [
PasswordProvider.load(module, config, account_handler)
for module, config in hs.config.authproviders.password_providers
]
logger.info("Extra password_providers: %s", self.password_providers)
self.password_auth_provider = hs.get_password_auth_provider()
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.auth.password_enabled
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
# start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = set()
if self._password_localdb_enabled:
login_types.add(LoginType.PASSWORD)
for provider in self.password_providers:
login_types.update(provider.get_supported_login_types().keys())
if not self._password_enabled:
login_types.discard(LoginType.PASSWORD)
# Some clients just pick the first type in the list. In this case, we want
# them to use PASSWORD (rather than token or whatever), so we want to make sure
# that comes first, where it's present.
self._supported_login_types = []
if LoginType.PASSWORD in login_types:
self._supported_login_types.append(LoginType.PASSWORD)
login_types.remove(LoginType.PASSWORD)
self._supported_login_types.extend(login_types)
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter(
@ -427,8 +394,7 @@ class AuthHandler:
ui_auth_types.add(LoginType.PASSWORD)
# also allow auth from password providers
for provider in self.password_providers:
for t in provider.get_supported_login_types().keys():
for t in self.password_auth_provider.get_supported_login_types().keys():
if t == LoginType.PASSWORD and not self._password_enabled:
continue
ui_auth_types.add(t)
@ -1038,7 +1004,25 @@ class AuthHandler:
Returns:
login types
"""
return self._supported_login_types
# Load any login types registered by modules
# This is stored in the password_auth_provider so this doesn't trigger
# any callbacks
types = list(self.password_auth_provider.get_supported_login_types().keys())
# This list should include PASSWORD if (either _password_localdb_enabled is
# true or if one of the modules registered it) AND _password_enabled is true
# Also:
# Some clients just pick the first type in the list. In this case, we want
# them to use PASSWORD (rather than token or whatever), so we want to make sure
# that comes first, where it's present.
if LoginType.PASSWORD in types:
types.remove(LoginType.PASSWORD)
if self._password_enabled:
types.insert(0, LoginType.PASSWORD)
elif self._password_localdb_enabled and self._password_enabled:
types.insert(0, LoginType.PASSWORD)
return types
async def validate_login(
self,
@ -1217,15 +1201,20 @@ class AuthHandler:
known_login_type = False
for provider in self.password_providers:
supported_login_types = provider.get_supported_login_types()
if login_type not in supported_login_types:
# this password provider doesn't understand this login type
continue
# Check if login_type matches a type registered by one of the modules
# We don't need to remove LoginType.PASSWORD from the list if password login is
# disabled, since if that were the case then by this point we know that the
# login_type is not LoginType.PASSWORD
supported_login_types = self.password_auth_provider.get_supported_login_types()
# check if the login type being used is supported by a module
if login_type in supported_login_types:
# Make a note that this login type is supported by the server
known_login_type = True
# Get all the fields expected for this login types
login_fields = supported_login_types[login_type]
# go through the login submission and keep track of which required fields are
# provided/not provided
missing_fields = []
login_dict = {}
for f in login_fields:
@ -1233,6 +1222,7 @@ class AuthHandler:
missing_fields.append(f)
else:
login_dict[f] = login_submission[f]
# raise an error if any of the expected fields for that login type weren't provided
if missing_fields:
raise SynapseError(
400,
@ -1240,10 +1230,15 @@ class AuthHandler:
% (login_type, missing_fields),
)
result = await provider.check_auth(username, login_type, login_dict)
# call all of the check_auth hooks for that login_type
# it will return a result once the first success is found (or None otherwise)
result = await self.password_auth_provider.check_auth(
username, login_type, login_dict
)
if result:
return result
# if no module managed to authenticate the user, then fallback to built in password based auth
if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True
@ -1282,11 +1277,16 @@ class AuthHandler:
completed login/registration, or `None`. If authentication was
unsuccessful, `user_id` and `callback` are both `None`.
"""
for provider in self.password_providers:
result = await provider.check_3pid_auth(medium, address, password)
# call all of the check_3pid_auth callbacks
# Result will be from the first callback that returns something other than None
# If all the callbacks return None, then result is also set to None
result = await self.password_auth_provider.check_3pid_auth(
medium, address, password
)
if result:
return result
# if result is None then return (None, None)
return None, None
async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
@ -1365,9 +1365,8 @@ class AuthHandler:
user_info = await self.auth.get_user_by_access_token(access_token)
await self.store.delete_access_token(access_token)
# see if any of our auth providers want to know about this
for provider in self.password_providers:
await provider.on_logged_out(
# see if any modules want to know about this
await self.password_auth_provider.on_logged_out(
user_id=user_info.user_id,
device_id=user_info.device_id,
access_token=access_token,
@ -1398,10 +1397,9 @@ class AuthHandler:
user_id, except_token_id=except_token_id, device_id=device_id
)
# see if any of our auth providers want to know about this
for provider in self.password_providers:
# see if any modules want to know about this
for token, _, device_id in tokens_and_devices:
await provider.on_logged_out(
await self.password_auth_provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token
)
@ -1811,40 +1809,228 @@ class MacaroonGenerator:
return macaroon
class PasswordProvider:
"""Wrapper for a password auth provider module
def load_legacy_password_auth_providers(hs: "HomeServer") -> None:
module_api = hs.get_module_api()
for module, config in hs.config.authproviders.password_providers:
load_single_legacy_password_auth_provider(
module=module, config=config, api=module_api
)
This class abstracts out all of the backwards-compatibility hacks for
password providers, to provide a consistent interface.
"""
@classmethod
def load(
cls, module: Type, config: JsonDict, module_api: ModuleApi
) -> "PasswordProvider":
def load_single_legacy_password_auth_provider(
module: Type, config: JsonDict, api: ModuleApi
) -> None:
try:
pp = module(config=config, account_handler=module_api)
provider = module(config=config, account_handler=api)
except Exception as e:
logger.error("Error while initializing %r: %s", module, e)
raise
return cls(pp, module_api)
def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
self._pp = pp
self._module_api = module_api
# The known hooks. If a module implements a method who's name appears in this set
# we'll want to register it
password_auth_provider_methods = {
"check_3pid_auth",
"on_logged_out",
}
self._supported_login_types = {}
# All methods that the module provides should be async, but this wasn't enforced
# in the old module system, so we wrap them if needed
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
return None
# grandfather in check_password support
if hasattr(self._pp, "check_password"):
self._supported_login_types[LoginType.PASSWORD] = ("password",)
# We need to wrap check_password because its old form would return a boolean
# but we now want it to behave just like check_auth() and return the matrix id of
# the user if authentication succeeded or None otherwise
if f.__name__ == "check_password":
g = getattr(self._pp, "get_supported_login_types", None)
if g:
self._supported_login_types.update(g())
async def wrapped_check_password(
username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
assert f is not None
def __str__(self) -> str:
return str(self._pp)
matrix_user_id = api.get_qualified_user_id(username)
password = login_dict["password"]
is_valid = await f(matrix_user_id, password)
if is_valid:
return matrix_user_id, None
return None
return wrapped_check_password
# We need to wrap check_auth as in the old form it could return
# just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
if f.__name__ == "check_auth":
async def wrapped_check_auth(
username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
assert f is not None
result = await f(username, login_type, login_dict)
if isinstance(result, str):
return result, None
return result
return wrapped_check_auth
# We need to wrap check_3pid_auth as in the old form it could return
# just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
if f.__name__ == "check_3pid_auth":
async def wrapped_check_3pid_auth(
medium: str, address: str, password: str
) -> Optional[Tuple[str, Optional[Callable]]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
assert f is not None
result = await f(medium, address, password)
if isinstance(result, str):
return result, None
return result
return wrapped_check_3pid_auth
def run(*args: Tuple, **kwargs: Dict) -> Awaitable:
# mypy doesn't do well across function boundaries so we need to tell it
# f is definitely not None.
assert f is not None
return maybe_awaitable(f(*args, **kwargs))
return run
# populate hooks with the implemented methods, wrapped with async_wrapper
hooks = {
hook: async_wrapper(getattr(provider, hook, None))
for hook in password_auth_provider_methods
}
supported_login_types = {}
# call get_supported_login_types and add that to the dict
g = getattr(provider, "get_supported_login_types", None)
if g is not None:
# Note the old module style also called get_supported_login_types at loading time
# and it is synchronous
supported_login_types.update(g())
auth_checkers = {}
# Legacy modules have a check_auth method which expects to be called with one of
# the keys returned by get_supported_login_types. New style modules register a
# dictionary of login_type->check_auth_method mappings
check_auth = async_wrapper(getattr(provider, "check_auth", None))
if check_auth is not None:
for login_type, fields in supported_login_types.items():
# need tuple(fields) since fields can be any Iterable type (so may not be hashable)
auth_checkers[(login_type, tuple(fields))] = check_auth
# if it has a "check_password" method then it should handle all auth checks
# with login type of LoginType.PASSWORD
check_password = async_wrapper(getattr(provider, "check_password", None))
if check_password is not None:
# need to use a tuple here for ("password",) not a list since lists aren't hashable
auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password
api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers)
CHECK_3PID_AUTH_CALLBACK = Callable[
[str, str, str],
Awaitable[
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
]
ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
CHECK_AUTH_CALLBACK = Callable[
[str, str, JsonDict],
Awaitable[
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
]
class PasswordAuthProvider:
"""
A class that the AuthHandler calls when authenticating users
It allows modules to provide alternative methods for authentication
"""
def __init__(self) -> None:
# lists of callbacks
self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
# Mapping from login type to login parameters
self._supported_login_types: Dict[str, Iterable[str]] = {}
# Mapping from login type to auth checker callbacks
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
def register_password_auth_provider_callbacks(
self,
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None,
) -> None:
# Register check_3pid_auth callback
if check_3pid_auth is not None:
self.check_3pid_auth_callbacks.append(check_3pid_auth)
# register on_logged_out callback
if on_logged_out is not None:
self.on_logged_out_callbacks.append(on_logged_out)
if auth_checkers is not None:
# register a new supported login_type
# Iterate through all of the types being registered
for (login_type, fields), callback in auth_checkers.items():
# Note: fields may be empty here. This would allow a modules auth checker to
# be called with just 'login_type' and no password or other secrets
# Need to check that all the field names are strings or may get nasty errors later
for f in fields:
if not isinstance(f, str):
raise RuntimeError(
"A module tried to register support for login type: %s with parameters %s"
" but all parameter names must be strings"
% (login_type, fields)
)
# 2 modules supporting the same login type must expect the same fields
# e.g. 1 can't expect "pass" if the other expects "password"
# so throw an exception if that happens
if login_type not in self._supported_login_types.get(login_type, []):
self._supported_login_types[login_type] = fields
else:
fields_currently_supported = self._supported_login_types.get(
login_type
)
if fields_currently_supported != fields:
raise RuntimeError(
"A module tried to register support for login type: %s with parameters %s"
" but another module had already registered support for that type with parameters %s"
% (login_type, fields, fields_currently_supported)
)
# Add the new method to the list of auth_checker_callbacks for this login type
self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
"""Get the login types supported by this password provider
@ -1852,20 +2038,15 @@ class PasswordProvider:
Returns a map from a login type identifier (such as m.login.password) to an
iterable giving the fields which must be provided by the user in the submission
to the /login API.
This wrapper adds m.login.password to the list if the underlying password
provider supports the check_password() api.
"""
return self._supported_login_types
async def check_auth(
self, username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]:
) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
"""Check if the user has presented valid login credentials
This wrapper also calls check_password() if the underlying password provider
supports the check_password() api and the login type is m.login.password.
Args:
username: user id presented by the client. Either an MXID or an unqualified
username.
@ -1879,63 +2060,130 @@ class PasswordProvider:
user, and `callback` is an optional callback which will be called with the
result from the /login call (including access_token, device_id, etc.)
"""
# first grandfather in a call to check_password
if login_type == LoginType.PASSWORD:
check_password = getattr(self._pp, "check_password", None)
if check_password:
qualified_user_id = self._module_api.get_qualified_user_id(username)
is_valid = await check_password(
qualified_user_id, login_dict["password"]
# Go through all callbacks for the login type until one returns with a value
# other than None (i.e. until a callback returns a success)
for callback in self.auth_checker_callbacks[login_type]:
try:
result = await callback(username, login_type, login_dict)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
if result is not None:
# Check that the callback returned a Tuple[str, Optional[Callable]]
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
# result is always the right type, but as it is 3rd party code it might not be
if not isinstance(result, tuple) or len(result) != 2:
logger.warning(
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
if is_valid:
return qualified_user_id, None
continue
check_auth = getattr(self._pp, "check_auth", None)
if not check_auth:
return None
result = await check_auth(username, login_type, login_dict)
# pull out the two parts of the tuple so we can do type checking
str_result, callback_result = result
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
return result, None
# the 1st item in the tuple should be a str
if not isinstance(str_result, str):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# the second should be Optional[Callable]
if callback_result is not None:
if not callable(callback_result):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# The result is a (str, Optional[callback]) tuple so return the successful result
return result
# If this point has been reached then none of the callbacks successfully authenticated
# the user so return None
return None
async def check_3pid_auth(
self, medium: str, address: str, password: str
) -> Optional[Tuple[str, Optional[Callable]]]:
g = getattr(self._pp, "check_3pid_auth", None)
if not g:
return None
) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
result = await g(medium, address, password)
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
return result, None
for callback in self.check_3pid_auth_callbacks:
try:
result = await callback(medium, address, password)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
if result is not None:
# Check that the callback returned a Tuple[str, Optional[Callable]]
# "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
# result is always the right type, but as it is 3rd party code it might not be
if not isinstance(result, tuple) or len(result) != 2:
logger.warning(
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# pull out the two parts of the tuple so we can do type checking
str_result, callback_result = result
# the 1st item in the tuple should be a str
if not isinstance(str_result, str):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# the second should be Optional[Callable]
if callback_result is not None:
if not callable(callback_result):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
result,
)
continue
# The result is a (str, Optional[callback]) tuple so return the successful result
return result
# If this point has been reached then none of the callbacks successfully authenticated
# the user so return None
return None
async def on_logged_out(
self, user_id: str, device_id: Optional[str], access_token: str
) -> None:
g = getattr(self._pp, "on_logged_out", None)
if not g:
return
# This might return an awaitable, if it does block the log out
# until it completes.
await maybe_awaitable(
g(
user_id=user_id,
device_id=device_id,
access_token=access_token,
)
)
# call all of the on_logged_out callbacks
for callback in self.on_logged_out_callbacks:
try:
callback(user_id, device_id, access_token)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue

View file

@ -45,6 +45,7 @@ from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client.login import LoginResponse
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.storage.state import StateFilter
@ -83,6 +84,8 @@ __all__ = [
"DirectServeJsonResource",
"ModuleApi",
"PRESENCE_ALL_USERS",
"LoginResponse",
"JsonDict",
]
logger = logging.getLogger(__name__)
@ -139,6 +142,7 @@ class ModuleApi:
self._spam_checker = hs.get_spam_checker()
self._account_validity_handler = hs.get_account_validity_handler()
self._third_party_event_rules = hs.get_third_party_event_rules()
self._password_auth_provider = hs.get_password_auth_provider()
self._presence_router = hs.get_presence_router()
#################################################################################
@ -164,6 +168,11 @@ class ModuleApi:
"""Registers callbacks for presence router capabilities."""
return self._presence_router.register_presence_router_callbacks
@property
def register_password_auth_provider_callbacks(self):
"""Registers callbacks for password auth provider capabilities."""
return self._password_auth_provider.register_password_auth_provider_callbacks
def register_web_resource(self, path: str, resource: IResource):
"""Registers a web resource to be served at the given path.

View file

@ -65,7 +65,7 @@ from synapse.handlers.account_data import AccountDataHandler
from synapse.handlers.account_validity import AccountValidityHandler
from synapse.handlers.admin import AdminHandler
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
from synapse.handlers.auth import AuthHandler, MacaroonGenerator, PasswordAuthProvider
from synapse.handlers.cas import CasHandler
from synapse.handlers.deactivate_account import DeactivateAccountHandler
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
@ -687,6 +687,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_third_party_event_rules(self) -> ThirdPartyEventRules:
return ThirdPartyEventRules(self)
@cache_in_self
def get_password_auth_provider(self) -> PasswordAuthProvider:
return PasswordAuthProvider()
@cache_in_self
def get_room_member_handler(self) -> RoomMemberHandler:
if self.config.worker.worker_app:

View file

@ -549,6 +549,8 @@ def _apply_module_schemas(
database_engine:
config: application config
"""
# This is the old way for password_auth_provider modules to make changes
# to the database. This should instead be done using the module API
for (mod, _config) in config.authproviders.password_providers:
if not hasattr(mod, "get_db_schema_files"):
continue

View file

@ -20,6 +20,8 @@ from unittest.mock import Mock
from twisted.internet import defer
import synapse
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
from synapse.rest.client import devices, login
from synapse.types import JsonDict
@ -36,8 +38,8 @@ ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_servi
mock_password_provider = Mock()
class PasswordOnlyAuthProvider:
"""A password_provider which only implements `check_password`."""
class LegacyPasswordOnlyAuthProvider:
"""A legacy password_provider which only implements `check_password`."""
@staticmethod
def parse_config(self):
@ -50,8 +52,8 @@ class PasswordOnlyAuthProvider:
return mock_password_provider.check_password(*args)
class CustomAuthProvider:
"""A password_provider which implements a custom login type."""
class LegacyCustomAuthProvider:
"""A legacy password_provider which implements a custom login type."""
@staticmethod
def parse_config(self):
@ -67,7 +69,23 @@ class CustomAuthProvider:
return mock_password_provider.check_auth(*args)
class PasswordCustomAuthProvider:
class CustomAuthProvider:
"""A module which registers password_auth_provider callbacks for a custom login type."""
@staticmethod
def parse_config(self):
pass
def __init__(self, config, api: ModuleApi):
api.register_password_auth_provider_callbacks(
auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
)
def check_auth(self, *args):
return mock_password_provider.check_auth(*args)
class LegacyPasswordCustomAuthProvider:
"""A password_provider which implements password login via `check_auth`, as well
as a custom type."""
@ -85,8 +103,32 @@ class PasswordCustomAuthProvider:
return mock_password_provider.check_auth(*args)
def providers_config(*providers: Type[Any]) -> dict:
"""Returns a config dict that will enable the given password auth providers"""
class PasswordCustomAuthProvider:
"""A module which registers password_auth_provider callbacks for a custom login type.
as well as a password login"""
@staticmethod
def parse_config(self):
pass
def __init__(self, config, api: ModuleApi):
api.register_password_auth_provider_callbacks(
auth_checkers={
("test.login_type", ("test_field",)): self.check_auth,
("m.login.password", ("password",)): self.check_auth,
},
)
pass
def check_auth(self, *args):
return mock_password_provider.check_auth(*args)
def check_pass(self, *args):
return mock_password_provider.check_password(*args)
def legacy_providers_config(*providers: Type[Any]) -> dict:
"""Returns a config dict that will enable the given legacy password auth providers"""
return {
"password_providers": [
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
@ -95,6 +137,16 @@ def providers_config(*providers: Type[Any]) -> dict:
}
def providers_config(*providers: Type[Any]) -> dict:
"""Returns a config dict that will enable the given modules"""
return {
"modules": [
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
for provider in providers
]
}
class PasswordAuthProviderTests(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
@ -107,8 +159,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
super().setUp()
@override_config(providers_config(PasswordOnlyAuthProvider))
def test_password_only_auth_provider_login(self):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver()
# Load the modules into the homeserver
module_api = hs.get_module_api()
for module, config in hs.config.modules.loaded_modules:
module(config=config, api=module_api)
load_legacy_password_auth_providers(hs)
return hs
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_progiver_login_legacy(self):
self.password_only_auth_provider_login_test_body()
def password_only_auth_provider_login_test_body(self):
# login flows should only have m.login.password
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
@ -138,8 +203,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"@ USER🙂NAME :test", " pASS😢word "
)
@override_config(providers_config(PasswordOnlyAuthProvider))
def test_password_only_auth_provider_ui_auth(self):
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_provider_ui_auth_legacy(self):
self.password_only_auth_provider_ui_auth_test_body()
def password_only_auth_provider_ui_auth_test_body(self):
"""UI Auth should delegate correctly to the password provider"""
# create the user, otherwise access doesn't work
@ -172,8 +240,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@override_config(providers_config(PasswordOnlyAuthProvider))
def test_local_user_fallback_login(self):
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_local_user_fallback_login_legacy(self):
self.local_user_fallback_login_test_body()
def local_user_fallback_login_test_body(self):
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
@ -186,8 +257,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@localuser:test", channel.json_body["user_id"])
@override_config(providers_config(PasswordOnlyAuthProvider))
def test_local_user_fallback_ui_auth(self):
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_local_user_fallback_ui_auth_legacy(self):
self.local_user_fallback_ui_auth_test_body()
def local_user_fallback_ui_auth_test_body(self):
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
@ -223,11 +297,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
**providers_config(PasswordOnlyAuthProvider),
**legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
def test_no_local_user_fallback_login(self):
def test_no_local_user_fallback_login_legacy(self):
self.no_local_user_fallback_login_test_body()
def no_local_user_fallback_login_test_body(self):
"""localdb_enabled can block login with the local password"""
self.register_user("localuser", "localpass")
@ -242,11 +319,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
**providers_config(PasswordOnlyAuthProvider),
**legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
def test_no_local_user_fallback_ui_auth(self):
def test_no_local_user_fallback_ui_auth_legacy(self):
self.no_local_user_fallback_ui_auth_test_body()
def no_local_user_fallback_ui_auth_test_body(self):
"""localdb_enabled can block ui auth with the local password"""
self.register_user("localuser", "localpass")
@ -280,11 +360,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
**providers_config(PasswordOnlyAuthProvider),
**legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_auth_disabled(self):
def test_password_auth_disabled_legacy(self):
self.password_auth_disabled_test_body()
def password_auth_disabled_test_body(self):
"""password auth doesn't work if it's disabled across the board"""
# login flows should be empty
flows = self._get_login_flows()
@ -295,8 +378,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_password.assert_not_called()
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_login_legacy(self):
self.custom_auth_provider_login_test_body()
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_login(self):
self.custom_auth_provider_login_test_body()
def custom_auth_provider_login_test_body(self):
# login flows should have the custom flow and m.login.password, since we
# haven't disabled local password lookup.
# (password must come first, because reasons)
@ -312,7 +402,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
mock_password_provider.check_auth.return_value = defer.succeed(
("@user:bz", None)
)
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
@ -325,7 +417,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# in these cases, but at least we can guard against the API changing
# unexpectedly
mock_password_provider.check_auth.return_value = defer.succeed(
"@ MALFORMED! :bz"
("@ MALFORMED! :bz", None)
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
self.assertEqual(channel.code, 200, channel.result)
@ -334,8 +426,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
)
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_ui_auth_legacy(self):
self.custom_auth_provider_ui_auth_test_body()
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_ui_auth(self):
self.custom_auth_provider_ui_auth_test_body()
def custom_auth_provider_ui_auth_test_body(self):
# register the user and log in twice, to get two devices
self.register_user("localuser", "localpass")
tok1 = self.login("localuser", "localpass")
@ -367,7 +466,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# right params, but authing as the wrong user
mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
mock_password_provider.check_auth.return_value = defer.succeed(
("@user:bz", None)
)
body["auth"]["test_field"] = "foo"
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 403)
@ -379,7 +480,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# and finally, succeed
mock_password_provider.check_auth.return_value = defer.succeed(
"@localuser:test"
("@localuser:test", None)
)
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 200)
@ -387,8 +488,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"localuser", "test.login_type", {"test_field": "foo"}
)
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_callback_legacy(self):
self.custom_auth_provider_callback_test_body()
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_callback(self):
self.custom_auth_provider_callback_test_body()
def custom_auth_provider_callback_test_body(self):
callback = Mock(return_value=defer.succeed(None))
mock_password_provider.check_auth.return_value = defer.succeed(
@ -410,10 +518,22 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
for p in ["user_id", "access_token", "device_id", "home_server"]:
self.assertIn(p, call_args[0])
@override_config(
{
**legacy_providers_config(LegacyCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_custom_auth_password_disabled_legacy(self):
self.custom_auth_password_disabled_test_body()
@override_config(
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
)
def test_custom_auth_password_disabled(self):
self.custom_auth_password_disabled_test_body()
def custom_auth_password_disabled_test_body(self):
"""Test login with a custom auth provider where password login is disabled"""
self.register_user("localuser", "localpass")
@ -425,6 +545,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
@override_config(
{
**legacy_providers_config(LegacyCustomAuthProvider),
"password_config": {"enabled": False, "localdb_enabled": False},
}
)
def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
self.custom_auth_password_disabled_localdb_enabled_test_body()
@override_config(
{
**providers_config(CustomAuthProvider),
@ -432,6 +561,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
}
)
def test_custom_auth_password_disabled_localdb_enabled(self):
self.custom_auth_password_disabled_localdb_enabled_test_body()
def custom_auth_password_disabled_localdb_enabled_test_body(self):
"""Check the localdb_enabled == enabled == False
Regression test for https://github.com/matrix-org/synapse/issues/8914: check
@ -448,6 +580,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
@override_config(
{
**legacy_providers_config(LegacyPasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_login_legacy(self):
self.password_custom_auth_password_disabled_login_test_body()
@override_config(
{
**providers_config(PasswordCustomAuthProvider),
@ -455,6 +596,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
}
)
def test_password_custom_auth_password_disabled_login(self):
self.password_custom_auth_password_disabled_login_test_body()
def password_custom_auth_password_disabled_login_test_body(self):
"""log in with a custom auth provider which implements password, but password
login is disabled"""
self.register_user("localuser", "localpass")
@ -466,6 +610,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_password.assert_not_called()
@override_config(
{
**legacy_providers_config(LegacyPasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
self.password_custom_auth_password_disabled_ui_auth_test_body()
@override_config(
{
@ -474,12 +628,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
}
)
def test_password_custom_auth_password_disabled_ui_auth(self):
self.password_custom_auth_password_disabled_ui_auth_test_body()
def password_custom_auth_password_disabled_ui_auth_test_body(self):
"""UI Auth with a custom auth provider which implements password, but password
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass")
mock_password_provider.check_auth.return_value = defer.succeed(
"@localuser:test"
("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
self.assertEqual(channel.code, 200, channel.result)
@ -516,6 +673,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"Password login has been disabled.", channel.json_body["error"]
)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_password.assert_not_called()
mock_password_provider.reset_mock()
# successful auth
@ -526,6 +684,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_auth.assert_called_once_with(
"localuser", "test.login_type", {"test_field": "x"}
)
mock_password_provider.check_password.assert_not_called()
@override_config(
{
**legacy_providers_config(LegacyCustomAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
def test_custom_auth_no_local_user_fallback_legacy(self):
self.custom_auth_no_local_user_fallback_test_body()
@override_config(
{
@ -534,6 +702,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
}
)
def test_custom_auth_no_local_user_fallback(self):
self.custom_auth_no_local_user_fallback_test_body()
def custom_auth_no_local_user_fallback_test_body(self):
"""Test login with a custom auth provider where the local db is disabled"""
self.register_user("localuser", "localpass")