ModuleAPI SSO auth callbacks (#15207)

Signed-off-by: Andrii Yasynyshyn yasinishyn.a.n@gmail.com
This commit is contained in:
Andrew Yasinishyn 2023-12-01 16:31:50 +02:00 committed by GitHub
parent 579c6be5f6
commit 63d96bfc61
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 56 additions and 2 deletions

View file

@ -0,0 +1 @@
Adds on_user_login ModuleAPI callback allowing to execute custom code after (on) Auth.

View file

@ -42,3 +42,16 @@ operations to keep track of them. (e.g. add them to a database table). The user
represented by their Matrix user ID. represented by their Matrix user ID.
If multiple modules implement this callback, Synapse runs them all in order. If multiple modules implement this callback, Synapse runs them all in order.
### `on_user_login`
_First introduced in Synapse v1.98.0_
```python
async def on_user_login(user_id: str, auth_provider_type: str, auth_provider_id: str) -> None
```
Called after successfully login or registration of a user for cases when module needs to perform extra operations after auth.
represented by their Matrix user ID.
If multiple modules implement this callback, Synapse runs them all in order.

View file

@ -296,8 +296,7 @@ impl<'source> FromPyObject<'source> for JsonValue {
match l.iter().map(SimpleJsonValue::extract).collect() { match l.iter().map(SimpleJsonValue::extract).collect() {
Ok(a) => Ok(JsonValue::Array(a)), Ok(a) => Ok(JsonValue::Array(a)),
Err(e) => Err(PyTypeError::new_err(format!( Err(e) => Err(PyTypeError::new_err(format!(
"Can't convert to JsonValue::Array: {}", "Can't convert to JsonValue::Array: {e}"
e
))), ))),
} }
} else if let Ok(v) = SimpleJsonValue::extract(ob) { } else if let Ok(v) = SimpleJsonValue::extract(ob) {

View file

@ -98,6 +98,22 @@ class AccountValidityHandler:
for callback in self._module_api_callbacks.on_user_registration_callbacks: for callback in self._module_api_callbacks.on_user_registration_callbacks:
await callback(user_id) await callback(user_id)
async def on_user_login(
self,
user_id: str,
auth_provider_type: Optional[str],
auth_provider_id: Optional[str],
) -> None:
"""Tell third-party modules about a user logins.
Args:
user_id: The mxID of the user.
auth_provider_type: The type of login.
auth_provider_id: The ID of the auth provider.
"""
for callback in self._module_api_callbacks.on_user_login_callbacks:
await callback(user_id, auth_provider_type, auth_provider_id)
@wrap_as_background_process("send_renewals") @wrap_as_background_process("send_renewals")
async def _send_renewal_emails(self) -> None: async def _send_renewal_emails(self) -> None:
"""Gets the list of users whose account is expiring in the amount of time """Gets the list of users whose account is expiring in the amount of time

View file

@ -212,6 +212,7 @@ class AuthHandler:
self._password_enabled_for_reauth = hs.config.auth.password_enabled_for_reauth self._password_enabled_for_reauth = hs.config.auth.password_enabled_for_reauth
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
self._account_validity_handler = hs.get_account_validity_handler()
# Ratelimiter for failed auth during UIA. Uses same ratelimit config # Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`. # as per `rc_login.failed_attempts`.
@ -1783,6 +1784,13 @@ class AuthHandler:
client_redirect_url, "loginToken", login_token client_redirect_url, "loginToken", login_token
) )
# Run post-login module callback handlers
await self._account_validity_handler.on_user_login(
user_id=registered_user_id,
auth_provider_type=LoginType.SSO,
auth_provider_id=auth_provider_id,
)
# if the client is whitelisted, we can redirect straight to it # if the client is whitelisted, we can redirect straight to it
if client_redirect_url.startswith(self._whitelisted_sso_clients): if client_redirect_url.startswith(self._whitelisted_sso_clients):
request.redirect(redirect_url) request.redirect(redirect_url)

View file

@ -80,6 +80,7 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
ON_LEGACY_ADMIN_REQUEST, ON_LEGACY_ADMIN_REQUEST,
ON_LEGACY_RENEW_CALLBACK, ON_LEGACY_RENEW_CALLBACK,
ON_LEGACY_SEND_MAIL_CALLBACK, ON_LEGACY_SEND_MAIL_CALLBACK,
ON_USER_LOGIN_CALLBACK,
ON_USER_REGISTRATION_CALLBACK, ON_USER_REGISTRATION_CALLBACK,
) )
from synapse.module_api.callbacks.spamchecker_callbacks import ( from synapse.module_api.callbacks.spamchecker_callbacks import (
@ -334,6 +335,7 @@ class ModuleApi:
*, *,
is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
on_user_login: Optional[ON_USER_LOGIN_CALLBACK] = None,
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
@ -345,6 +347,7 @@ class ModuleApi:
return self._callbacks.account_validity.register_callbacks( return self._callbacks.account_validity.register_callbacks(
is_user_expired=is_user_expired, is_user_expired=is_user_expired,
on_user_registration=on_user_registration, on_user_registration=on_user_registration,
on_user_login=on_user_login,
on_legacy_send_mail=on_legacy_send_mail, on_legacy_send_mail=on_legacy_send_mail,
on_legacy_renew=on_legacy_renew, on_legacy_renew=on_legacy_renew,
on_legacy_admin_request=on_legacy_admin_request, on_legacy_admin_request=on_legacy_admin_request,

View file

@ -22,6 +22,7 @@ logger = logging.getLogger(__name__)
# Types for callbacks to be registered via the module api # Types for callbacks to be registered via the module api
IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]] IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]]
ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable] ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable]
ON_USER_LOGIN_CALLBACK = Callable[[str, Optional[str], Optional[str]], Awaitable]
# Temporary hooks to allow for a transition from `/_matrix/client` endpoints # Temporary hooks to allow for a transition from `/_matrix/client` endpoints
# to `/_synapse/client/account_validity`. See `register_callbacks` below. # to `/_synapse/client/account_validity`. See `register_callbacks` below.
ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable] ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable]
@ -33,6 +34,7 @@ class AccountValidityModuleApiCallbacks:
def __init__(self) -> None: def __init__(self) -> None:
self.is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = [] self.is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
self.on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = [] self.on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = []
self.on_user_login_callbacks: List[ON_USER_LOGIN_CALLBACK] = []
self.on_legacy_send_mail_callback: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None self.on_legacy_send_mail_callback: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None
self.on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None self.on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None
@ -44,6 +46,7 @@ class AccountValidityModuleApiCallbacks:
self, self,
is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
on_user_login: Optional[ON_USER_LOGIN_CALLBACK] = None,
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
@ -55,6 +58,9 @@ class AccountValidityModuleApiCallbacks:
if on_user_registration is not None: if on_user_registration is not None:
self.on_user_registration_callbacks.append(on_user_registration) self.on_user_registration_callbacks.append(on_user_registration)
if on_user_login is not None:
self.on_user_login_callbacks.append(on_user_login)
# The builtin account validity feature exposes 3 endpoints (send_mail, renew, and # The builtin account validity feature exposes 3 endpoints (send_mail, renew, and
# an admin one). As part of moving the feature into a module, we need to change # an admin one). As part of moving the feature into a module, we need to change
# the path from /_matrix/client/unstable/account_validity/... to # the path from /_matrix/client/unstable/account_validity/... to

View file

@ -115,6 +115,7 @@ class LoginRestServlet(RestServlet):
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
self._spam_checker = hs.get_module_api_callbacks().spam_checker self._spam_checker = hs.get_module_api_callbacks().spam_checker
self._account_validity_handler = hs.get_account_validity_handler()
self._well_known_builder = WellKnownBuilder(hs) self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter( self._address_ratelimiter = Ratelimiter(
@ -470,6 +471,13 @@ class LoginRestServlet(RestServlet):
device_id=device_id, device_id=device_id,
) )
# execute the callback
await self._account_validity_handler.on_user_login(
user_id,
auth_provider_type=login_submission.get("type"),
auth_provider_id=auth_provider_id,
)
if valid_until_ms is not None: if valid_until_ms is not None:
expires_in_ms = valid_until_ms - self.clock.time_msec() expires_in_ms = valid_until_ms - self.clock.time_msec()
result["expires_in_ms"] = expires_in_ms result["expires_in_ms"] = expires_in_ms