mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-21 17:15:38 +03:00
Added display_name_claim in jwt_config which sets the user's display name upon registration (#17708)
This commit is contained in:
parent
60aebdb27e
commit
05576f0b4b
6 changed files with 50 additions and 6 deletions
1
changelog.d/17708.feature
Normal file
1
changelog.d/17708.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Added the `display_name_claim` option to the JWT configuration. This option allows specifying the claim key that contains the user's display name in the JWT payload.
|
|
@ -3722,6 +3722,8 @@ Additional sub-options for this setting include:
|
|||
Required if `enabled` is set to true.
|
||||
* `subject_claim`: Name of the claim containing a unique identifier for the user.
|
||||
Optional, defaults to `sub`.
|
||||
* `display_name_claim`: Name of the claim containing the display name for the user. Optional.
|
||||
If provided, the display name will be set to the value of this claim upon first login.
|
||||
* `issuer`: The issuer to validate the "iss" claim against. Optional. If provided the
|
||||
"iss" claim will be required and validated for all JSON web tokens.
|
||||
* `audiences`: A list of audiences to validate the "aud" claim against. Optional.
|
||||
|
@ -3736,6 +3738,7 @@ jwt_config:
|
|||
secret: "provided-by-your-issuer"
|
||||
algorithm: "provided-by-your-issuer"
|
||||
subject_claim: "name_of_claim"
|
||||
display_name_claim: "name_of_claim"
|
||||
issuer: "provided-by-your-issuer"
|
||||
audiences:
|
||||
- "provided-by-your-issuer"
|
||||
|
|
|
@ -38,6 +38,7 @@ class JWTConfig(Config):
|
|||
self.jwt_algorithm = jwt_config["algorithm"]
|
||||
|
||||
self.jwt_subject_claim = jwt_config.get("subject_claim", "sub")
|
||||
self.jwt_display_name_claim = jwt_config.get("display_name_claim")
|
||||
|
||||
# The issuer and audiences are optional, if provided, it is asserted
|
||||
# that the claims exist on the JWT.
|
||||
|
@ -49,5 +50,6 @@ class JWTConfig(Config):
|
|||
self.jwt_secret = None
|
||||
self.jwt_algorithm = None
|
||||
self.jwt_subject_claim = None
|
||||
self.jwt_display_name_claim = None
|
||||
self.jwt_issuer = None
|
||||
self.jwt_audiences = None
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from authlib.jose import JsonWebToken, JWTClaims
|
||||
from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
|
||||
|
@ -36,11 +36,12 @@ class JwtHandler:
|
|||
|
||||
self.jwt_secret = hs.config.jwt.jwt_secret
|
||||
self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
|
||||
self.jwt_display_name_claim = hs.config.jwt.jwt_display_name_claim
|
||||
self.jwt_algorithm = hs.config.jwt.jwt_algorithm
|
||||
self.jwt_issuer = hs.config.jwt.jwt_issuer
|
||||
self.jwt_audiences = hs.config.jwt.jwt_audiences
|
||||
|
||||
def validate_login(self, login_submission: JsonDict) -> str:
|
||||
def validate_login(self, login_submission: JsonDict) -> Tuple[str, Optional[str]]:
|
||||
"""
|
||||
Authenticates the user for the /login API
|
||||
|
||||
|
@ -49,7 +50,8 @@ class JwtHandler:
|
|||
(including 'type' and other relevant fields)
|
||||
|
||||
Returns:
|
||||
The user ID that is logging in.
|
||||
A tuple of (user_id, display_name) of the user that is logging in.
|
||||
If the JWT does not contain a display name, the second element of the tuple will be None.
|
||||
|
||||
Raises:
|
||||
LoginError if there was an authentication problem.
|
||||
|
@ -109,4 +111,10 @@ class JwtHandler:
|
|||
if user is None:
|
||||
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
|
||||
|
||||
return UserID(user, self.hs.hostname).to_string()
|
||||
default_display_name = None
|
||||
if self.jwt_display_name_claim:
|
||||
display_name_claim = claims.get(self.jwt_display_name_claim)
|
||||
if display_name_claim is not None:
|
||||
default_display_name = display_name_claim
|
||||
|
||||
return UserID(user, self.hs.hostname).to_string(), default_display_name
|
||||
|
|
|
@ -363,6 +363,7 @@ class LoginRestServlet(RestServlet):
|
|||
login_submission: JsonDict,
|
||||
callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None,
|
||||
create_non_existent_users: bool = False,
|
||||
default_display_name: Optional[str] = None,
|
||||
ratelimit: bool = True,
|
||||
auth_provider_id: Optional[str] = None,
|
||||
should_issue_refresh_token: bool = False,
|
||||
|
@ -410,7 +411,8 @@ class LoginRestServlet(RestServlet):
|
|||
canonical_uid = await self.auth_handler.check_user_exists(user_id)
|
||||
if not canonical_uid:
|
||||
canonical_uid = await self.registration_handler.register_user(
|
||||
localpart=UserID.from_string(user_id).localpart
|
||||
localpart=UserID.from_string(user_id).localpart,
|
||||
default_display_name=default_display_name,
|
||||
)
|
||||
user_id = canonical_uid
|
||||
|
||||
|
@ -546,11 +548,14 @@ class LoginRestServlet(RestServlet):
|
|||
Returns:
|
||||
The body of the JSON response.
|
||||
"""
|
||||
user_id = self.hs.get_jwt_handler().validate_login(login_submission)
|
||||
user_id, default_display_name = self.hs.get_jwt_handler().validate_login(
|
||||
login_submission
|
||||
)
|
||||
return await self._complete_login(
|
||||
user_id,
|
||||
login_submission,
|
||||
create_non_existent_users=True,
|
||||
default_display_name=default_display_name,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
request_info=request_info,
|
||||
)
|
||||
|
|
|
@ -1047,6 +1047,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
|||
servlets = [
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
login.register_servlets,
|
||||
profile.register_servlets,
|
||||
]
|
||||
|
||||
jwt_secret = "secret"
|
||||
|
@ -1202,6 +1203,30 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["user_id"], "@frog:test")
|
||||
|
||||
@override_config(
|
||||
{"jwt_config": {**base_config, "display_name_claim": "display_name"}}
|
||||
)
|
||||
def test_login_custom_display_name(self) -> None:
|
||||
"""Test setting a custom display name."""
|
||||
localpart = "pinkie"
|
||||
user_id = f"@{localpart}:test"
|
||||
display_name = "Pinkie Pie"
|
||||
|
||||
# Perform the login, specifying a custom display name.
|
||||
channel = self.jwt_login({"sub": localpart, "display_name": display_name})
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["user_id"], user_id)
|
||||
|
||||
# Fetch the user's display name and check that it was set correctly.
|
||||
access_token = channel.json_body["access_token"]
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v3/profile/{user_id}/displayname",
|
||||
access_token=access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, msg=channel.result)
|
||||
self.assertEqual(channel.json_body["displayname"], display_name)
|
||||
|
||||
def test_login_no_token(self) -> None:
|
||||
params = {"type": "org.matrix.login.jwt"}
|
||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
|
|
Loading…
Reference in a new issue