mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-22 01:25:44 +03:00
Added display_name_claim in jwt_config which sets the user's display name upon registration (#17708)
This commit is contained in:
parent
60aebdb27e
commit
05576f0b4b
6 changed files with 50 additions and 6 deletions
1
changelog.d/17708.feature
Normal file
1
changelog.d/17708.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Added the `display_name_claim` option to the JWT configuration. This option allows specifying the claim key that contains the user's display name in the JWT payload.
|
|
@ -3722,6 +3722,8 @@ Additional sub-options for this setting include:
|
||||||
Required if `enabled` is set to true.
|
Required if `enabled` is set to true.
|
||||||
* `subject_claim`: Name of the claim containing a unique identifier for the user.
|
* `subject_claim`: Name of the claim containing a unique identifier for the user.
|
||||||
Optional, defaults to `sub`.
|
Optional, defaults to `sub`.
|
||||||
|
* `display_name_claim`: Name of the claim containing the display name for the user. Optional.
|
||||||
|
If provided, the display name will be set to the value of this claim upon first login.
|
||||||
* `issuer`: The issuer to validate the "iss" claim against. Optional. If provided the
|
* `issuer`: The issuer to validate the "iss" claim against. Optional. If provided the
|
||||||
"iss" claim will be required and validated for all JSON web tokens.
|
"iss" claim will be required and validated for all JSON web tokens.
|
||||||
* `audiences`: A list of audiences to validate the "aud" claim against. Optional.
|
* `audiences`: A list of audiences to validate the "aud" claim against. Optional.
|
||||||
|
@ -3736,6 +3738,7 @@ jwt_config:
|
||||||
secret: "provided-by-your-issuer"
|
secret: "provided-by-your-issuer"
|
||||||
algorithm: "provided-by-your-issuer"
|
algorithm: "provided-by-your-issuer"
|
||||||
subject_claim: "name_of_claim"
|
subject_claim: "name_of_claim"
|
||||||
|
display_name_claim: "name_of_claim"
|
||||||
issuer: "provided-by-your-issuer"
|
issuer: "provided-by-your-issuer"
|
||||||
audiences:
|
audiences:
|
||||||
- "provided-by-your-issuer"
|
- "provided-by-your-issuer"
|
||||||
|
|
|
@ -38,6 +38,7 @@ class JWTConfig(Config):
|
||||||
self.jwt_algorithm = jwt_config["algorithm"]
|
self.jwt_algorithm = jwt_config["algorithm"]
|
||||||
|
|
||||||
self.jwt_subject_claim = jwt_config.get("subject_claim", "sub")
|
self.jwt_subject_claim = jwt_config.get("subject_claim", "sub")
|
||||||
|
self.jwt_display_name_claim = jwt_config.get("display_name_claim")
|
||||||
|
|
||||||
# The issuer and audiences are optional, if provided, it is asserted
|
# The issuer and audiences are optional, if provided, it is asserted
|
||||||
# that the claims exist on the JWT.
|
# that the claims exist on the JWT.
|
||||||
|
@ -49,5 +50,6 @@ class JWTConfig(Config):
|
||||||
self.jwt_secret = None
|
self.jwt_secret = None
|
||||||
self.jwt_algorithm = None
|
self.jwt_algorithm = None
|
||||||
self.jwt_subject_claim = None
|
self.jwt_subject_claim = None
|
||||||
|
self.jwt_display_name_claim = None
|
||||||
self.jwt_issuer = None
|
self.jwt_issuer = None
|
||||||
self.jwt_audiences = None
|
self.jwt_audiences = None
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
# [This file includes modifications made by New Vector Limited]
|
# [This file includes modifications made by New Vector Limited]
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Optional, Tuple
|
||||||
|
|
||||||
from authlib.jose import JsonWebToken, JWTClaims
|
from authlib.jose import JsonWebToken, JWTClaims
|
||||||
from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
|
from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
|
||||||
|
@ -36,11 +36,12 @@ class JwtHandler:
|
||||||
|
|
||||||
self.jwt_secret = hs.config.jwt.jwt_secret
|
self.jwt_secret = hs.config.jwt.jwt_secret
|
||||||
self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
|
self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
|
||||||
|
self.jwt_display_name_claim = hs.config.jwt.jwt_display_name_claim
|
||||||
self.jwt_algorithm = hs.config.jwt.jwt_algorithm
|
self.jwt_algorithm = hs.config.jwt.jwt_algorithm
|
||||||
self.jwt_issuer = hs.config.jwt.jwt_issuer
|
self.jwt_issuer = hs.config.jwt.jwt_issuer
|
||||||
self.jwt_audiences = hs.config.jwt.jwt_audiences
|
self.jwt_audiences = hs.config.jwt.jwt_audiences
|
||||||
|
|
||||||
def validate_login(self, login_submission: JsonDict) -> str:
|
def validate_login(self, login_submission: JsonDict) -> Tuple[str, Optional[str]]:
|
||||||
"""
|
"""
|
||||||
Authenticates the user for the /login API
|
Authenticates the user for the /login API
|
||||||
|
|
||||||
|
@ -49,7 +50,8 @@ class JwtHandler:
|
||||||
(including 'type' and other relevant fields)
|
(including 'type' and other relevant fields)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The user ID that is logging in.
|
A tuple of (user_id, display_name) of the user that is logging in.
|
||||||
|
If the JWT does not contain a display name, the second element of the tuple will be None.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
LoginError if there was an authentication problem.
|
LoginError if there was an authentication problem.
|
||||||
|
@ -109,4 +111,10 @@ class JwtHandler:
|
||||||
if user is None:
|
if user is None:
|
||||||
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
|
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
return UserID(user, self.hs.hostname).to_string()
|
default_display_name = None
|
||||||
|
if self.jwt_display_name_claim:
|
||||||
|
display_name_claim = claims.get(self.jwt_display_name_claim)
|
||||||
|
if display_name_claim is not None:
|
||||||
|
default_display_name = display_name_claim
|
||||||
|
|
||||||
|
return UserID(user, self.hs.hostname).to_string(), default_display_name
|
||||||
|
|
|
@ -363,6 +363,7 @@ class LoginRestServlet(RestServlet):
|
||||||
login_submission: JsonDict,
|
login_submission: JsonDict,
|
||||||
callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None,
|
callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None,
|
||||||
create_non_existent_users: bool = False,
|
create_non_existent_users: bool = False,
|
||||||
|
default_display_name: Optional[str] = None,
|
||||||
ratelimit: bool = True,
|
ratelimit: bool = True,
|
||||||
auth_provider_id: Optional[str] = None,
|
auth_provider_id: Optional[str] = None,
|
||||||
should_issue_refresh_token: bool = False,
|
should_issue_refresh_token: bool = False,
|
||||||
|
@ -410,7 +411,8 @@ class LoginRestServlet(RestServlet):
|
||||||
canonical_uid = await self.auth_handler.check_user_exists(user_id)
|
canonical_uid = await self.auth_handler.check_user_exists(user_id)
|
||||||
if not canonical_uid:
|
if not canonical_uid:
|
||||||
canonical_uid = await self.registration_handler.register_user(
|
canonical_uid = await self.registration_handler.register_user(
|
||||||
localpart=UserID.from_string(user_id).localpart
|
localpart=UserID.from_string(user_id).localpart,
|
||||||
|
default_display_name=default_display_name,
|
||||||
)
|
)
|
||||||
user_id = canonical_uid
|
user_id = canonical_uid
|
||||||
|
|
||||||
|
@ -546,11 +548,14 @@ class LoginRestServlet(RestServlet):
|
||||||
Returns:
|
Returns:
|
||||||
The body of the JSON response.
|
The body of the JSON response.
|
||||||
"""
|
"""
|
||||||
user_id = self.hs.get_jwt_handler().validate_login(login_submission)
|
user_id, default_display_name = self.hs.get_jwt_handler().validate_login(
|
||||||
|
login_submission
|
||||||
|
)
|
||||||
return await self._complete_login(
|
return await self._complete_login(
|
||||||
user_id,
|
user_id,
|
||||||
login_submission,
|
login_submission,
|
||||||
create_non_existent_users=True,
|
create_non_existent_users=True,
|
||||||
|
default_display_name=default_display_name,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
request_info=request_info,
|
request_info=request_info,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1047,6 +1047,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
servlets = [
|
servlets = [
|
||||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
|
profile.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
jwt_secret = "secret"
|
jwt_secret = "secret"
|
||||||
|
@ -1202,6 +1203,30 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||||
self.assertEqual(channel.json_body["user_id"], "@frog:test")
|
self.assertEqual(channel.json_body["user_id"], "@frog:test")
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{"jwt_config": {**base_config, "display_name_claim": "display_name"}}
|
||||||
|
)
|
||||||
|
def test_login_custom_display_name(self) -> None:
|
||||||
|
"""Test setting a custom display name."""
|
||||||
|
localpart = "pinkie"
|
||||||
|
user_id = f"@{localpart}:test"
|
||||||
|
display_name = "Pinkie Pie"
|
||||||
|
|
||||||
|
# Perform the login, specifying a custom display name.
|
||||||
|
channel = self.jwt_login({"sub": localpart, "display_name": display_name})
|
||||||
|
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||||
|
self.assertEqual(channel.json_body["user_id"], user_id)
|
||||||
|
|
||||||
|
# Fetch the user's display name and check that it was set correctly.
|
||||||
|
access_token = channel.json_body["access_token"]
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/_matrix/client/v3/profile/{user_id}/displayname",
|
||||||
|
access_token=access_token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||||
|
self.assertEqual(channel.json_body["displayname"], display_name)
|
||||||
|
|
||||||
def test_login_no_token(self) -> None:
|
def test_login_no_token(self) -> None:
|
||||||
params = {"type": "org.matrix.login.jwt"}
|
params = {"type": "org.matrix.login.jwt"}
|
||||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||||
|
|
Loading…
Reference in a new issue