mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-24 18:45:52 +03:00
JWT OIDC secrets for Sign in with Apple (#9549)
Apple had to be special. They want a client secret which is generated from an EC key. Fixes #9220. Also fixes #9212 while I'm here.
This commit is contained in:
parent
9cd18cc588
commit
eaada74075
11 changed files with 444 additions and 47 deletions
|
@ -20,9 +20,10 @@ recursive-include scripts *
|
||||||
recursive-include scripts-dev *
|
recursive-include scripts-dev *
|
||||||
recursive-include synapse *.pyi
|
recursive-include synapse *.pyi
|
||||||
recursive-include tests *.py
|
recursive-include tests *.py
|
||||||
include tests/http/ca.crt
|
recursive-include tests *.pem
|
||||||
include tests/http/ca.key
|
recursive-include tests *.p8
|
||||||
include tests/http/server.key
|
recursive-include tests *.crt
|
||||||
|
recursive-include tests *.key
|
||||||
|
|
||||||
recursive-include synapse/res *
|
recursive-include synapse/res *
|
||||||
recursive-include synapse/static *.css
|
recursive-include synapse/static *.css
|
||||||
|
|
1
changelog.d/9549.feature
Normal file
1
changelog.d/9549.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add support for generating JSON Web Tokens dynamically for use as OIDC client secrets.
|
|
@ -386,7 +386,7 @@ oidc_providers:
|
||||||
config:
|
config:
|
||||||
subject_claim: "id"
|
subject_claim: "id"
|
||||||
localpart_template: "{{ user.login }}"
|
localpart_template: "{{ user.login }}"
|
||||||
display_name_template: "{{ user.full_name }}"
|
display_name_template: "{{ user.full_name }}"
|
||||||
```
|
```
|
||||||
|
|
||||||
### XWiki
|
### XWiki
|
||||||
|
@ -401,8 +401,7 @@ oidc_providers:
|
||||||
idp_name: "XWiki"
|
idp_name: "XWiki"
|
||||||
issuer: "https://myxwikihost/xwiki/oidc/"
|
issuer: "https://myxwikihost/xwiki/oidc/"
|
||||||
client_id: "your-client-id" # TO BE FILLED
|
client_id: "your-client-id" # TO BE FILLED
|
||||||
# Needed until https://github.com/matrix-org/synapse/issues/9212 is fixed
|
client_auth_method: none
|
||||||
client_secret: "dontcare"
|
|
||||||
scopes: ["openid", "profile"]
|
scopes: ["openid", "profile"]
|
||||||
user_profile_method: "userinfo_endpoint"
|
user_profile_method: "userinfo_endpoint"
|
||||||
user_mapping_provider:
|
user_mapping_provider:
|
||||||
|
@ -410,3 +409,40 @@ oidc_providers:
|
||||||
localpart_template: "{{ user.preferred_username }}"
|
localpart_template: "{{ user.preferred_username }}"
|
||||||
display_name_template: "{{ user.name }}"
|
display_name_template: "{{ user.name }}"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Apple
|
||||||
|
|
||||||
|
Configuring "Sign in with Apple" (SiWA) requires an Apple Developer account.
|
||||||
|
|
||||||
|
You will need to create a new "Services ID" for SiWA, and create and download a
|
||||||
|
private key with "SiWA" enabled.
|
||||||
|
|
||||||
|
As well as the private key file, you will need:
|
||||||
|
* Client ID: the "identifier" you gave the "Services ID"
|
||||||
|
* Team ID: a 10-character ID associated with your developer account.
|
||||||
|
* Key ID: the 10-character identifier for the key.
|
||||||
|
|
||||||
|
https://help.apple.com/developer-account/?lang=en#/dev77c875b7e has more
|
||||||
|
documentation on setting up SiWA.
|
||||||
|
|
||||||
|
The synapse config will look like this:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
- idp_id: apple
|
||||||
|
idp_name: Apple
|
||||||
|
issuer: "https://appleid.apple.com"
|
||||||
|
client_id: "your-client-id" # Set to the "identifier" for your "ServicesID"
|
||||||
|
client_auth_method: "client_secret_post"
|
||||||
|
client_secret_jwt_key:
|
||||||
|
key_file: "/path/to/AuthKey_KEYIDCODE.p8" # point to your key file
|
||||||
|
jwt_header:
|
||||||
|
alg: ES256
|
||||||
|
kid: "KEYIDCODE" # Set to the 10-char Key ID
|
||||||
|
jwt_payload:
|
||||||
|
iss: TEAMIDCODE # Set to the 10-char Team ID
|
||||||
|
scopes: ["name", "email", "openid"]
|
||||||
|
authorization_endpoint: https://appleid.apple.com/auth/authorize?response_mode=form_post
|
||||||
|
user_mapping_provider:
|
||||||
|
config:
|
||||||
|
email_template: "{{ user.email }}"
|
||||||
|
```
|
||||||
|
|
|
@ -1779,7 +1779,26 @@ saml2_config:
|
||||||
#
|
#
|
||||||
# client_id: Required. oauth2 client id to use.
|
# client_id: Required. oauth2 client id to use.
|
||||||
#
|
#
|
||||||
# client_secret: Required. oauth2 client secret to use.
|
# client_secret: oauth2 client secret to use. May be omitted if
|
||||||
|
# client_secret_jwt_key is given, or if client_auth_method is 'none'.
|
||||||
|
#
|
||||||
|
# client_secret_jwt_key: Alternative to client_secret: details of a key used
|
||||||
|
# to create a JSON Web Token to be used as an OAuth2 client secret. If
|
||||||
|
# given, must be a dictionary with the following properties:
|
||||||
|
#
|
||||||
|
# key: a pem-encoded signing key. Must be a suitable key for the
|
||||||
|
# algorithm specified. Required unless 'key_file' is given.
|
||||||
|
#
|
||||||
|
# key_file: the path to file containing a pem-encoded signing key file.
|
||||||
|
# Required unless 'key' is given.
|
||||||
|
#
|
||||||
|
# jwt_header: a dictionary giving properties to include in the JWT
|
||||||
|
# header. Must include the key 'alg', giving the algorithm used to
|
||||||
|
# sign the JWT, such as "ES256", using the JWA identifiers in
|
||||||
|
# RFC7518.
|
||||||
|
#
|
||||||
|
# jwt_payload: an optional dictionary giving properties to include in
|
||||||
|
# the JWT payload. Normally this should include an 'iss' key.
|
||||||
#
|
#
|
||||||
# client_auth_method: auth method to use when exchanging the token. Valid
|
# client_auth_method: auth method to use when exchanging the token. Valid
|
||||||
# values are 'client_secret_basic' (default), 'client_secret_post' and
|
# values are 'client_secret_basic' (default), 'client_secret_post' and
|
||||||
|
|
|
@ -212,9 +212,8 @@ class Config:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def read_file(cls, file_path, config_name):
|
def read_file(cls, file_path, config_name):
|
||||||
cls.check_file(file_path, config_name)
|
"""Deprecated: call read_file directly"""
|
||||||
with open(file_path) as file_stream:
|
return read_file(file_path, (config_name,))
|
||||||
return file_stream.read()
|
|
||||||
|
|
||||||
def read_template(self, filename: str) -> jinja2.Template:
|
def read_template(self, filename: str) -> jinja2.Template:
|
||||||
"""Load a template file from disk.
|
"""Load a template file from disk.
|
||||||
|
@ -894,4 +893,35 @@ class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
|
||||||
return self._get_instance(key)
|
return self._get_instance(key)
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
|
def read_file(file_path: Any, config_path: Iterable[str]) -> str:
|
||||||
|
"""Check the given file exists, and read it into a string
|
||||||
|
|
||||||
|
If it does not, emit an error indicating the problem
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: the file to be read
|
||||||
|
config_path: where in the configuration file_path came from, so that a useful
|
||||||
|
error can be emitted if it does not exist.
|
||||||
|
Returns:
|
||||||
|
content of the file.
|
||||||
|
Raises:
|
||||||
|
ConfigError if there is a problem reading the file.
|
||||||
|
"""
|
||||||
|
if not isinstance(file_path, str):
|
||||||
|
raise ConfigError("%r is not a string", config_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.stat(file_path)
|
||||||
|
with open(file_path) as file_stream:
|
||||||
|
return file_stream.read()
|
||||||
|
except OSError as e:
|
||||||
|
raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Config",
|
||||||
|
"RootConfig",
|
||||||
|
"ShardedWorkerHandlingConfig",
|
||||||
|
"RoutableShardedWorkerHandlingConfig",
|
||||||
|
"read_file",
|
||||||
|
]
|
||||||
|
|
|
@ -152,3 +152,5 @@ class ShardedWorkerHandlingConfig:
|
||||||
|
|
||||||
class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
|
class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
|
||||||
def get_instance(self, key: str) -> str: ...
|
def get_instance(self, key: str) -> str: ...
|
||||||
|
|
||||||
|
def read_file(file_path: Any, config_path: Iterable[str]) -> str: ...
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import Iterable, Optional, Tuple, Type
|
from typing import Iterable, Mapping, Optional, Tuple, Type
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ from synapse.types import Collection, JsonDict
|
||||||
from synapse.util.module_loader import load_module
|
from synapse.util.module_loader import load_module
|
||||||
from synapse.util.stringutils import parse_and_validate_mxc_uri
|
from synapse.util.stringutils import parse_and_validate_mxc_uri
|
||||||
|
|
||||||
from ._base import Config, ConfigError
|
from ._base import Config, ConfigError, read_file
|
||||||
|
|
||||||
DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"
|
DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"
|
||||||
|
|
||||||
|
@ -97,7 +97,26 @@ class OIDCConfig(Config):
|
||||||
#
|
#
|
||||||
# client_id: Required. oauth2 client id to use.
|
# client_id: Required. oauth2 client id to use.
|
||||||
#
|
#
|
||||||
# client_secret: Required. oauth2 client secret to use.
|
# client_secret: oauth2 client secret to use. May be omitted if
|
||||||
|
# client_secret_jwt_key is given, or if client_auth_method is 'none'.
|
||||||
|
#
|
||||||
|
# client_secret_jwt_key: Alternative to client_secret: details of a key used
|
||||||
|
# to create a JSON Web Token to be used as an OAuth2 client secret. If
|
||||||
|
# given, must be a dictionary with the following properties:
|
||||||
|
#
|
||||||
|
# key: a pem-encoded signing key. Must be a suitable key for the
|
||||||
|
# algorithm specified. Required unless 'key_file' is given.
|
||||||
|
#
|
||||||
|
# key_file: the path to file containing a pem-encoded signing key file.
|
||||||
|
# Required unless 'key' is given.
|
||||||
|
#
|
||||||
|
# jwt_header: a dictionary giving properties to include in the JWT
|
||||||
|
# header. Must include the key 'alg', giving the algorithm used to
|
||||||
|
# sign the JWT, such as "ES256", using the JWA identifiers in
|
||||||
|
# RFC7518.
|
||||||
|
#
|
||||||
|
# jwt_payload: an optional dictionary giving properties to include in
|
||||||
|
# the JWT payload. Normally this should include an 'iss' key.
|
||||||
#
|
#
|
||||||
# client_auth_method: auth method to use when exchanging the token. Valid
|
# client_auth_method: auth method to use when exchanging the token. Valid
|
||||||
# values are 'client_secret_basic' (default), 'client_secret_post' and
|
# values are 'client_secret_basic' (default), 'client_secret_post' and
|
||||||
|
@ -240,7 +259,7 @@ class OIDCConfig(Config):
|
||||||
# jsonschema definition of the configuration settings for an oidc identity provider
|
# jsonschema definition of the configuration settings for an oidc identity provider
|
||||||
OIDC_PROVIDER_CONFIG_SCHEMA = {
|
OIDC_PROVIDER_CONFIG_SCHEMA = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["issuer", "client_id", "client_secret"],
|
"required": ["issuer", "client_id"],
|
||||||
"properties": {
|
"properties": {
|
||||||
"idp_id": {
|
"idp_id": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
@ -262,6 +281,30 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
|
||||||
"issuer": {"type": "string"},
|
"issuer": {"type": "string"},
|
||||||
"client_id": {"type": "string"},
|
"client_id": {"type": "string"},
|
||||||
"client_secret": {"type": "string"},
|
"client_secret": {"type": "string"},
|
||||||
|
"client_secret_jwt_key": {
|
||||||
|
"type": "object",
|
||||||
|
"required": ["jwt_header"],
|
||||||
|
"oneOf": [
|
||||||
|
{"required": ["key"]},
|
||||||
|
{"required": ["key_file"]},
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"key": {"type": "string"},
|
||||||
|
"key_file": {"type": "string"},
|
||||||
|
"jwt_header": {
|
||||||
|
"type": "object",
|
||||||
|
"required": ["alg"],
|
||||||
|
"properties": {
|
||||||
|
"alg": {"type": "string"},
|
||||||
|
},
|
||||||
|
"additionalProperties": {"type": "string"},
|
||||||
|
},
|
||||||
|
"jwt_payload": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
"client_auth_method": {
|
"client_auth_method": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
# the following list is the same as the keys of
|
# the following list is the same as the keys of
|
||||||
|
@ -404,6 +447,20 @@ def _parse_oidc_config_dict(
|
||||||
"idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
|
"idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key")
|
||||||
|
client_secret_jwt_key = None # type: Optional[OidcProviderClientSecretJwtKey]
|
||||||
|
if client_secret_jwt_key_config is not None:
|
||||||
|
keyfile = client_secret_jwt_key_config.get("key_file")
|
||||||
|
if keyfile:
|
||||||
|
key = read_file(keyfile, config_path + ("client_secret_jwt_key",))
|
||||||
|
else:
|
||||||
|
key = client_secret_jwt_key_config["key"]
|
||||||
|
client_secret_jwt_key = OidcProviderClientSecretJwtKey(
|
||||||
|
key=key,
|
||||||
|
jwt_header=client_secret_jwt_key_config["jwt_header"],
|
||||||
|
jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}),
|
||||||
|
)
|
||||||
|
|
||||||
return OidcProviderConfig(
|
return OidcProviderConfig(
|
||||||
idp_id=idp_id,
|
idp_id=idp_id,
|
||||||
idp_name=oidc_config.get("idp_name", "OIDC"),
|
idp_name=oidc_config.get("idp_name", "OIDC"),
|
||||||
|
@ -412,7 +469,8 @@ def _parse_oidc_config_dict(
|
||||||
discover=oidc_config.get("discover", True),
|
discover=oidc_config.get("discover", True),
|
||||||
issuer=oidc_config["issuer"],
|
issuer=oidc_config["issuer"],
|
||||||
client_id=oidc_config["client_id"],
|
client_id=oidc_config["client_id"],
|
||||||
client_secret=oidc_config["client_secret"],
|
client_secret=oidc_config.get("client_secret"),
|
||||||
|
client_secret_jwt_key=client_secret_jwt_key,
|
||||||
client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
|
client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
|
||||||
scopes=oidc_config.get("scopes", ["openid"]),
|
scopes=oidc_config.get("scopes", ["openid"]),
|
||||||
authorization_endpoint=oidc_config.get("authorization_endpoint"),
|
authorization_endpoint=oidc_config.get("authorization_endpoint"),
|
||||||
|
@ -427,6 +485,18 @@ def _parse_oidc_config_dict(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True)
|
||||||
|
class OidcProviderClientSecretJwtKey:
|
||||||
|
# a pem-encoded signing key
|
||||||
|
key = attr.ib(type=str)
|
||||||
|
|
||||||
|
# properties to include in the JWT header
|
||||||
|
jwt_header = attr.ib(type=Mapping[str, str])
|
||||||
|
|
||||||
|
# properties to include in the JWT payload.
|
||||||
|
jwt_payload = attr.ib(type=Mapping[str, str])
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True)
|
@attr.s(slots=True, frozen=True)
|
||||||
class OidcProviderConfig:
|
class OidcProviderConfig:
|
||||||
# a unique identifier for this identity provider. Used in the 'user_external_ids'
|
# a unique identifier for this identity provider. Used in the 'user_external_ids'
|
||||||
|
@ -452,8 +522,13 @@ class OidcProviderConfig:
|
||||||
# oauth2 client id to use
|
# oauth2 client id to use
|
||||||
client_id = attr.ib(type=str)
|
client_id = attr.ib(type=str)
|
||||||
|
|
||||||
# oauth2 client secret to use
|
# oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
|
||||||
client_secret = attr.ib(type=str)
|
# a secret.
|
||||||
|
client_secret = attr.ib(type=Optional[str])
|
||||||
|
|
||||||
|
# key to use to construct a JWT to use as a client secret. May be `None` if
|
||||||
|
# `client_secret` is set.
|
||||||
|
client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey])
|
||||||
|
|
||||||
# auth method to use when exchanging the token.
|
# auth method to use when exchanging the token.
|
||||||
# Valid values are 'client_secret_basic', 'client_secret_post' and
|
# Valid values are 'client_secret_basic', 'client_secret_post' and
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2020 Quentin Gliech
|
# Copyright 2020 Quentin Gliech
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -14,13 +15,13 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
|
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
from authlib.common.security import generate_token
|
from authlib.common.security import generate_token
|
||||||
from authlib.jose import JsonWebToken
|
from authlib.jose import JsonWebToken, jwt
|
||||||
from authlib.oauth2.auth import ClientAuth
|
from authlib.oauth2.auth import ClientAuth
|
||||||
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
|
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
|
||||||
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
|
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
|
||||||
|
@ -35,12 +36,15 @@ from typing_extensions import TypedDict
|
||||||
from twisted.web.client import readBody
|
from twisted.web.client import readBody
|
||||||
|
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
from synapse.config.oidc_config import OidcProviderConfig
|
from synapse.config.oidc_config import (
|
||||||
|
OidcProviderClientSecretJwtKey,
|
||||||
|
OidcProviderConfig,
|
||||||
|
)
|
||||||
from synapse.handlers.sso import MappingException, UserAttributes
|
from synapse.handlers.sso import MappingException, UserAttributes
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
||||||
from synapse.util import json_decoder
|
from synapse.util import Clock, json_decoder
|
||||||
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
|
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
|
||||||
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
|
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
|
||||||
|
|
||||||
|
@ -276,9 +280,21 @@ class OidcProvider:
|
||||||
|
|
||||||
self._scopes = provider.scopes
|
self._scopes = provider.scopes
|
||||||
self._user_profile_method = provider.user_profile_method
|
self._user_profile_method = provider.user_profile_method
|
||||||
|
|
||||||
|
client_secret = None # type: Union[None, str, JwtClientSecret]
|
||||||
|
if provider.client_secret:
|
||||||
|
client_secret = provider.client_secret
|
||||||
|
elif provider.client_secret_jwt_key:
|
||||||
|
client_secret = JwtClientSecret(
|
||||||
|
provider.client_secret_jwt_key,
|
||||||
|
provider.client_id,
|
||||||
|
provider.issuer,
|
||||||
|
hs.get_clock(),
|
||||||
|
)
|
||||||
|
|
||||||
self._client_auth = ClientAuth(
|
self._client_auth = ClientAuth(
|
||||||
provider.client_id,
|
provider.client_id,
|
||||||
provider.client_secret,
|
client_secret,
|
||||||
provider.client_auth_method,
|
provider.client_auth_method,
|
||||||
) # type: ClientAuth
|
) # type: ClientAuth
|
||||||
self._client_auth_method = provider.client_auth_method
|
self._client_auth_method = provider.client_auth_method
|
||||||
|
@ -977,6 +993,81 @@ class OidcProvider:
|
||||||
return str(remote_user_id)
|
return str(remote_user_id)
|
||||||
|
|
||||||
|
|
||||||
|
# number of seconds a newly-generated client secret should be valid for
|
||||||
|
CLIENT_SECRET_VALIDITY_SECONDS = 3600
|
||||||
|
|
||||||
|
# minimum remaining validity on a client secret before we should generate a new one
|
||||||
|
CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600
|
||||||
|
|
||||||
|
|
||||||
|
class JwtClientSecret:
|
||||||
|
"""A class which generates a new client secret on demand, based on a JWK
|
||||||
|
|
||||||
|
This implementation is designed to comply with the requirements for Apple Sign in:
|
||||||
|
https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048
|
||||||
|
|
||||||
|
It looks like those requirements are based on https://tools.ietf.org/html/rfc7523,
|
||||||
|
but it's worth noting that we still put the generated secret in the "client_secret"
|
||||||
|
field (or rather, whereever client_auth_method puts it) rather than in a
|
||||||
|
client_assertion field in the body as that RFC seems to require.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
key: OidcProviderClientSecretJwtKey,
|
||||||
|
oauth_client_id: str,
|
||||||
|
oauth_issuer: str,
|
||||||
|
clock: Clock,
|
||||||
|
):
|
||||||
|
self._key = key
|
||||||
|
self._oauth_client_id = oauth_client_id
|
||||||
|
self._oauth_issuer = oauth_issuer
|
||||||
|
self._clock = clock
|
||||||
|
self._cached_secret = b""
|
||||||
|
self._cached_secret_replacement_time = 0
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
# if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
|
||||||
|
# encode_client_secret_basic, which calls "{}".format(secret), which ends up
|
||||||
|
# here.
|
||||||
|
return self._get_secret().decode("ascii")
|
||||||
|
|
||||||
|
def __bytes__(self):
|
||||||
|
# if client_auth_method is client_secret_post, then ClientAuth.prepare calls
|
||||||
|
# encode_client_secret_post, which ends up here.
|
||||||
|
return self._get_secret()
|
||||||
|
|
||||||
|
def _get_secret(self) -> bytes:
|
||||||
|
now = self._clock.time()
|
||||||
|
|
||||||
|
# if we have enough validity on our existing secret, use it
|
||||||
|
if now < self._cached_secret_replacement_time:
|
||||||
|
return self._cached_secret
|
||||||
|
|
||||||
|
issued_at = int(now)
|
||||||
|
expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS
|
||||||
|
|
||||||
|
# we copy the configured header because jwt.encode modifies it.
|
||||||
|
header = dict(self._key.jwt_header)
|
||||||
|
|
||||||
|
# see https://tools.ietf.org/html/rfc7523#section-3
|
||||||
|
payload = {
|
||||||
|
"sub": self._oauth_client_id,
|
||||||
|
"aud": self._oauth_issuer,
|
||||||
|
"iat": issued_at,
|
||||||
|
"exp": expires_at,
|
||||||
|
**self._key.jwt_payload,
|
||||||
|
}
|
||||||
|
logger.info(
|
||||||
|
"Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
|
||||||
|
)
|
||||||
|
self._cached_secret = jwt.encode(header, payload, self._key.key)
|
||||||
|
self._cached_secret_replacement_time = (
|
||||||
|
expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
|
||||||
|
)
|
||||||
|
return self._cached_secret
|
||||||
|
|
||||||
|
|
||||||
class OidcSessionTokenGenerator:
|
class OidcSessionTokenGenerator:
|
||||||
"""Methods for generating and checking OIDC Session cookies."""
|
"""Methods for generating and checking OIDC Session cookies."""
|
||||||
|
|
||||||
|
|
5
tests/handlers/oidc_test_key.p8
Normal file
5
tests/handlers/oidc_test_key.p8
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
-----BEGIN PRIVATE KEY-----
|
||||||
|
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgrHMvFcFjFhei6gHp
|
||||||
|
Gfy4C8+6z7634MZbC7SSx4a17GahRANCAATp0YxEzGUXuqszggiFxczDdPgDpCJA
|
||||||
|
P18rRuN7FLwZDuzYQPb8zVd8eGh4BqxjiVocICnVWyaSWD96N00I96SW
|
||||||
|
-----END PRIVATE KEY-----
|
4
tests/handlers/oidc_test_key.pub.pem
Normal file
4
tests/handlers/oidc_test_key.pub.pem
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
-----BEGIN PUBLIC KEY-----
|
||||||
|
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE6dGMRMxlF7qrM4IIhcXMw3T4A6Qi
|
||||||
|
QD9fK0bjexS8GQ7s2ED2/M1XfHhoeAasY4laHCAp1Vsmklg/ejdNCPeklg==
|
||||||
|
-----END PUBLIC KEY-----
|
|
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
from urllib.parse import parse_qs, urlparse
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
from mock import ANY, Mock, patch
|
from mock import ANY, Mock, patch
|
||||||
|
@ -50,7 +51,18 @@ WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
|
||||||
JWKS_URI = ISSUER + ".well-known/jwks.json"
|
JWKS_URI = ISSUER + ".well-known/jwks.json"
|
||||||
|
|
||||||
# config for common cases
|
# config for common cases
|
||||||
COMMON_CONFIG = {
|
DEFAULT_CONFIG = {
|
||||||
|
"enabled": True,
|
||||||
|
"client_id": CLIENT_ID,
|
||||||
|
"client_secret": CLIENT_SECRET,
|
||||||
|
"issuer": ISSUER,
|
||||||
|
"scopes": SCOPES,
|
||||||
|
"user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
|
||||||
|
}
|
||||||
|
|
||||||
|
# extends the default config with explicit OAuth2 endpoints instead of using discovery
|
||||||
|
EXPLICIT_ENDPOINT_CONFIG = {
|
||||||
|
**DEFAULT_CONFIG,
|
||||||
"discover": False,
|
"discover": False,
|
||||||
"authorization_endpoint": AUTHORIZATION_ENDPOINT,
|
"authorization_endpoint": AUTHORIZATION_ENDPOINT,
|
||||||
"token_endpoint": TOKEN_ENDPOINT,
|
"token_endpoint": TOKEN_ENDPOINT,
|
||||||
|
@ -107,6 +119,32 @@ async def get_json(url):
|
||||||
return {"keys": []}
|
return {"keys": []}
|
||||||
|
|
||||||
|
|
||||||
|
def _key_file_path() -> str:
|
||||||
|
"""path to a file containing the private half of a test key"""
|
||||||
|
|
||||||
|
# this key was generated with:
|
||||||
|
# openssl ecparam -name prime256v1 -genkey -noout |
|
||||||
|
# openssl pkcs8 -topk8 -nocrypt -out oidc_test_key.p8
|
||||||
|
#
|
||||||
|
# we use PKCS8 rather than SEC-1 (which is what openssl ecparam spits out), because
|
||||||
|
# that's what Apple use, and we want to be sure that we work with Apple's keys.
|
||||||
|
#
|
||||||
|
# (For the record: both PKCS8 and SEC-1 specify (different) ways of representing
|
||||||
|
# keys using ASN.1. Both are then typically formatted using PEM, which says: use the
|
||||||
|
# base64-encoded DER encoding of ASN.1, with headers and footers. But we don't
|
||||||
|
# really need to care about any of that.)
|
||||||
|
return os.path.join(os.path.dirname(__file__), "oidc_test_key.p8")
|
||||||
|
|
||||||
|
|
||||||
|
def _public_key_file_path() -> str:
|
||||||
|
"""path to a file containing the public half of a test key"""
|
||||||
|
# this was generated with:
|
||||||
|
# openssl ec -in oidc_test_key.p8 -pubout -out oidc_test_key.pub.pem
|
||||||
|
#
|
||||||
|
# See above about where oidc_test_key.p8 came from
|
||||||
|
return os.path.join(os.path.dirname(__file__), "oidc_test_key.pub.pem")
|
||||||
|
|
||||||
|
|
||||||
class OidcHandlerTestCase(HomeserverTestCase):
|
class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
if not HAS_OIDC:
|
if not HAS_OIDC:
|
||||||
skip = "requires OIDC"
|
skip = "requires OIDC"
|
||||||
|
@ -114,20 +152,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
def default_config(self):
|
def default_config(self):
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["public_baseurl"] = BASE_URL
|
config["public_baseurl"] = BASE_URL
|
||||||
oidc_config = {
|
|
||||||
"enabled": True,
|
|
||||||
"client_id": CLIENT_ID,
|
|
||||||
"client_secret": CLIENT_SECRET,
|
|
||||||
"issuer": ISSUER,
|
|
||||||
"scopes": SCOPES,
|
|
||||||
"user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Update this config with what's in the default config so that
|
|
||||||
# override_config works as expected.
|
|
||||||
oidc_config.update(config.get("oidc_config", {}))
|
|
||||||
config["oidc_config"] = oidc_config
|
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor, clock):
|
||||||
|
@ -170,13 +194,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.render_error.reset_mock()
|
self.render_error.reset_mock()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_config(self):
|
def test_config(self):
|
||||||
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
||||||
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
|
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
|
||||||
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
|
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
|
||||||
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
|
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
|
||||||
|
|
||||||
@override_config({"oidc_config": {"discover": True}})
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
|
||||||
def test_discovery(self):
|
def test_discovery(self):
|
||||||
"""The handler should discover the endpoints from OIDC discovery document."""
|
"""The handler should discover the endpoints from OIDC discovery document."""
|
||||||
# This would throw if some metadata were invalid
|
# This would throw if some metadata were invalid
|
||||||
|
@ -195,13 +220,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.get_success(self.provider.load_metadata())
|
self.get_success(self.provider.load_metadata())
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
|
|
||||||
@override_config({"oidc_config": COMMON_CONFIG})
|
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
||||||
def test_no_discovery(self):
|
def test_no_discovery(self):
|
||||||
"""When discovery is disabled, it should not try to load from discovery document."""
|
"""When discovery is disabled, it should not try to load from discovery document."""
|
||||||
self.get_success(self.provider.load_metadata())
|
self.get_success(self.provider.load_metadata())
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
|
|
||||||
@override_config({"oidc_config": COMMON_CONFIG})
|
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
||||||
def test_load_jwks(self):
|
def test_load_jwks(self):
|
||||||
"""JWKS loading is done once (then cached) if used."""
|
"""JWKS loading is done once (then cached) if used."""
|
||||||
jwks = self.get_success(self.provider.load_jwks())
|
jwks = self.get_success(self.provider.load_jwks())
|
||||||
|
@ -236,6 +261,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.http_client.get_json.assert_not_called()
|
self.http_client.get_json.assert_not_called()
|
||||||
self.assertEqual(jwks, {"keys": []})
|
self.assertEqual(jwks, {"keys": []})
|
||||||
|
|
||||||
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_validate_config(self):
|
def test_validate_config(self):
|
||||||
"""Provider metadatas are extensively validated."""
|
"""Provider metadatas are extensively validated."""
|
||||||
h = self.provider
|
h = self.provider
|
||||||
|
@ -318,13 +344,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
# Shouldn't raise with a valid userinfo, even without jwks
|
# Shouldn't raise with a valid userinfo, even without jwks
|
||||||
force_load_metadata()
|
force_load_metadata()
|
||||||
|
|
||||||
@override_config({"oidc_config": {"skip_verification": True}})
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
|
||||||
def test_skip_verification(self):
|
def test_skip_verification(self):
|
||||||
"""Provider metadata validation can be disabled by config."""
|
"""Provider metadata validation can be disabled by config."""
|
||||||
with self.metadata_edit({"issuer": "http://insecure"}):
|
with self.metadata_edit({"issuer": "http://insecure"}):
|
||||||
# This should not throw
|
# This should not throw
|
||||||
get_awaitable_result(self.provider.load_metadata())
|
get_awaitable_result(self.provider.load_metadata())
|
||||||
|
|
||||||
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_redirect_request(self):
|
def test_redirect_request(self):
|
||||||
"""The redirect request has the right arguments & generates a valid session cookie."""
|
"""The redirect request has the right arguments & generates a valid session cookie."""
|
||||||
req = Mock(spec=["cookies"])
|
req = Mock(spec=["cookies"])
|
||||||
|
@ -368,6 +395,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.assertEqual(params["nonce"], [nonce])
|
self.assertEqual(params["nonce"], [nonce])
|
||||||
self.assertEqual(redirect, "http://client/redirect")
|
self.assertEqual(redirect, "http://client/redirect")
|
||||||
|
|
||||||
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_callback_error(self):
|
def test_callback_error(self):
|
||||||
"""Errors from the provider returned in the callback are displayed."""
|
"""Errors from the provider returned in the callback are displayed."""
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
|
@ -379,6 +407,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_client", "some description")
|
self.assertRenderedError("invalid_client", "some description")
|
||||||
|
|
||||||
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_callback(self):
|
def test_callback(self):
|
||||||
"""Code callback works and display errors if something went wrong.
|
"""Code callback works and display errors if something went wrong.
|
||||||
|
|
||||||
|
@ -480,6 +509,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_request")
|
self.assertRenderedError("invalid_request")
|
||||||
|
|
||||||
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_callback_session(self):
|
def test_callback_session(self):
|
||||||
"""The callback verifies the session presence and validity"""
|
"""The callback verifies the session presence and validity"""
|
||||||
request = Mock(spec=["args", "getCookie", "cookies"])
|
request = Mock(spec=["args", "getCookie", "cookies"])
|
||||||
|
@ -522,7 +552,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.get_success(self.handler.handle_oidc_callback(request))
|
self.get_success(self.handler.handle_oidc_callback(request))
|
||||||
self.assertRenderedError("invalid_request")
|
self.assertRenderedError("invalid_request")
|
||||||
|
|
||||||
@override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
|
@override_config(
|
||||||
|
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
|
||||||
|
)
|
||||||
def test_exchange_code(self):
|
def test_exchange_code(self):
|
||||||
"""Code exchange behaves correctly and handles various error scenarios."""
|
"""Code exchange behaves correctly and handles various error scenarios."""
|
||||||
token = {"type": "bearer"}
|
token = {"type": "bearer"}
|
||||||
|
@ -607,9 +639,105 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
@override_config(
|
@override_config(
|
||||||
{
|
{
|
||||||
"oidc_config": {
|
"oidc_config": {
|
||||||
|
"enabled": True,
|
||||||
|
"client_id": CLIENT_ID,
|
||||||
|
"issuer": ISSUER,
|
||||||
|
"client_auth_method": "client_secret_post",
|
||||||
|
"client_secret_jwt_key": {
|
||||||
|
"key_file": _key_file_path(),
|
||||||
|
"jwt_header": {"alg": "ES256", "kid": "ABC789"},
|
||||||
|
"jwt_payload": {"iss": "DEFGHI"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_exchange_code_jwt_key(self):
|
||||||
|
"""Test that code exchange works with a JWK client secret."""
|
||||||
|
from authlib.jose import jwt
|
||||||
|
|
||||||
|
token = {"type": "bearer"}
|
||||||
|
self.http_client.request = simple_async_mock(
|
||||||
|
return_value=FakeResponse(
|
||||||
|
code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
code = "code"
|
||||||
|
|
||||||
|
# advance the clock a bit before we start, so we aren't working with zero
|
||||||
|
# timestamps.
|
||||||
|
self.reactor.advance(1000)
|
||||||
|
start_time = self.reactor.seconds()
|
||||||
|
ret = self.get_success(self.provider._exchange_code(code))
|
||||||
|
|
||||||
|
self.assertEqual(ret, token)
|
||||||
|
|
||||||
|
# the request should have hit the token endpoint
|
||||||
|
kwargs = self.http_client.request.call_args[1]
|
||||||
|
self.assertEqual(kwargs["method"], "POST")
|
||||||
|
self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
|
||||||
|
|
||||||
|
# the client secret provided to the should be a jwt which can be checked with
|
||||||
|
# the public key
|
||||||
|
args = parse_qs(kwargs["data"].decode("utf-8"))
|
||||||
|
secret = args["client_secret"][0]
|
||||||
|
with open(_public_key_file_path()) as f:
|
||||||
|
key = f.read()
|
||||||
|
claims = jwt.decode(secret, key)
|
||||||
|
self.assertEqual(claims.header["kid"], "ABC789")
|
||||||
|
self.assertEqual(claims["aud"], ISSUER)
|
||||||
|
self.assertEqual(claims["iss"], "DEFGHI")
|
||||||
|
self.assertEqual(claims["sub"], CLIENT_ID)
|
||||||
|
self.assertEqual(claims["iat"], start_time)
|
||||||
|
self.assertGreater(claims["exp"], start_time)
|
||||||
|
|
||||||
|
# check the rest of the POSTed data
|
||||||
|
self.assertEqual(args["grant_type"], ["authorization_code"])
|
||||||
|
self.assertEqual(args["code"], [code])
|
||||||
|
self.assertEqual(args["client_id"], [CLIENT_ID])
|
||||||
|
self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"oidc_config": {
|
||||||
|
"enabled": True,
|
||||||
|
"client_id": CLIENT_ID,
|
||||||
|
"issuer": ISSUER,
|
||||||
|
"client_auth_method": "none",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_exchange_code_no_auth(self):
|
||||||
|
"""Test that code exchange works with no client secret."""
|
||||||
|
token = {"type": "bearer"}
|
||||||
|
self.http_client.request = simple_async_mock(
|
||||||
|
return_value=FakeResponse(
|
||||||
|
code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
code = "code"
|
||||||
|
ret = self.get_success(self.provider._exchange_code(code))
|
||||||
|
|
||||||
|
self.assertEqual(ret, token)
|
||||||
|
|
||||||
|
# the request should have hit the token endpoint
|
||||||
|
kwargs = self.http_client.request.call_args[1]
|
||||||
|
self.assertEqual(kwargs["method"], "POST")
|
||||||
|
self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
|
||||||
|
|
||||||
|
# check the POSTed data
|
||||||
|
args = parse_qs(kwargs["data"].decode("utf-8"))
|
||||||
|
self.assertEqual(args["grant_type"], ["authorization_code"])
|
||||||
|
self.assertEqual(args["code"], [code])
|
||||||
|
self.assertEqual(args["client_id"], [CLIENT_ID])
|
||||||
|
self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"oidc_config": {
|
||||||
|
**DEFAULT_CONFIG,
|
||||||
"user_mapping_provider": {
|
"user_mapping_provider": {
|
||||||
"module": __name__ + ".TestMappingProviderExtra"
|
"module": __name__ + ".TestMappingProviderExtra"
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -652,6 +780,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
new_user=True,
|
new_user=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_map_userinfo_to_user(self):
|
def test_map_userinfo_to_user(self):
|
||||||
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
|
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
|
@ -692,7 +821,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"Mapping provider does not support de-duplicating Matrix IDs",
|
"Mapping provider does not support de-duplicating Matrix IDs",
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"oidc_config": {"allow_existing_users": True}})
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
|
||||||
def test_map_userinfo_to_existing_user(self):
|
def test_map_userinfo_to_existing_user(self):
|
||||||
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
|
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
|
||||||
store = self.hs.get_datastore()
|
store = self.hs.get_datastore()
|
||||||
|
@ -772,6 +901,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
|
"@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_map_userinfo_to_invalid_localpart(self):
|
def test_map_userinfo_to_invalid_localpart(self):
|
||||||
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -782,9 +912,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
@override_config(
|
@override_config(
|
||||||
{
|
{
|
||||||
"oidc_config": {
|
"oidc_config": {
|
||||||
|
**DEFAULT_CONFIG,
|
||||||
"user_mapping_provider": {
|
"user_mapping_provider": {
|
||||||
"module": __name__ + ".TestMappingProviderFailures"
|
"module": __name__ + ".TestMappingProviderFailures"
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -829,6 +960,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
|
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
||||||
def test_empty_localpart(self):
|
def test_empty_localpart(self):
|
||||||
"""Attempts to map onto an empty localpart should be rejected."""
|
"""Attempts to map onto an empty localpart should be rejected."""
|
||||||
userinfo = {
|
userinfo = {
|
||||||
|
@ -841,9 +973,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
@override_config(
|
@override_config(
|
||||||
{
|
{
|
||||||
"oidc_config": {
|
"oidc_config": {
|
||||||
|
**DEFAULT_CONFIG,
|
||||||
"user_mapping_provider": {
|
"user_mapping_provider": {
|
||||||
"config": {"localpart_template": "{{ user.username }}"}
|
"config": {"localpart_template": "{{ user.username }}"}
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue