Added display_name_claim in jwt_config which sets the user's display name upon registration (#17708)

This commit is contained in:
Nathan 2024-10-09 14:21:08 +02:00 committed by GitHub
parent 60aebdb27e
commit 05576f0b4b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 50 additions and 6 deletions

View file

@ -0,0 +1 @@
Added the `display_name_claim` option to the JWT configuration. This option allows specifying the claim key that contains the user's display name in the JWT payload.

View file

@ -3722,6 +3722,8 @@ Additional sub-options for this setting include:
Required if `enabled` is set to true. Required if `enabled` is set to true.
* `subject_claim`: Name of the claim containing a unique identifier for the user. * `subject_claim`: Name of the claim containing a unique identifier for the user.
Optional, defaults to `sub`. Optional, defaults to `sub`.
* `display_name_claim`: Name of the claim containing the display name for the user. Optional.
If provided, the display name will be set to the value of this claim upon first login.
* `issuer`: The issuer to validate the "iss" claim against. Optional. If provided the * `issuer`: The issuer to validate the "iss" claim against. Optional. If provided the
"iss" claim will be required and validated for all JSON web tokens. "iss" claim will be required and validated for all JSON web tokens.
* `audiences`: A list of audiences to validate the "aud" claim against. Optional. * `audiences`: A list of audiences to validate the "aud" claim against. Optional.
@ -3736,6 +3738,7 @@ jwt_config:
secret: "provided-by-your-issuer" secret: "provided-by-your-issuer"
algorithm: "provided-by-your-issuer" algorithm: "provided-by-your-issuer"
subject_claim: "name_of_claim" subject_claim: "name_of_claim"
display_name_claim: "name_of_claim"
issuer: "provided-by-your-issuer" issuer: "provided-by-your-issuer"
audiences: audiences:
- "provided-by-your-issuer" - "provided-by-your-issuer"

View file

@ -38,6 +38,7 @@ class JWTConfig(Config):
self.jwt_algorithm = jwt_config["algorithm"] self.jwt_algorithm = jwt_config["algorithm"]
self.jwt_subject_claim = jwt_config.get("subject_claim", "sub") self.jwt_subject_claim = jwt_config.get("subject_claim", "sub")
self.jwt_display_name_claim = jwt_config.get("display_name_claim")
# The issuer and audiences are optional, if provided, it is asserted # The issuer and audiences are optional, if provided, it is asserted
# that the claims exist on the JWT. # that the claims exist on the JWT.
@ -49,5 +50,6 @@ class JWTConfig(Config):
self.jwt_secret = None self.jwt_secret = None
self.jwt_algorithm = None self.jwt_algorithm = None
self.jwt_subject_claim = None self.jwt_subject_claim = None
self.jwt_display_name_claim = None
self.jwt_issuer = None self.jwt_issuer = None
self.jwt_audiences = None self.jwt_audiences = None

View file

@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited] # [This file includes modifications made by New Vector Limited]
# #
# #
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional, Tuple
from authlib.jose import JsonWebToken, JWTClaims from authlib.jose import JsonWebToken, JWTClaims
from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
@ -36,11 +36,12 @@ class JwtHandler:
self.jwt_secret = hs.config.jwt.jwt_secret self.jwt_secret = hs.config.jwt.jwt_secret
self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
self.jwt_display_name_claim = hs.config.jwt.jwt_display_name_claim
self.jwt_algorithm = hs.config.jwt.jwt_algorithm self.jwt_algorithm = hs.config.jwt.jwt_algorithm
self.jwt_issuer = hs.config.jwt.jwt_issuer self.jwt_issuer = hs.config.jwt.jwt_issuer
self.jwt_audiences = hs.config.jwt.jwt_audiences self.jwt_audiences = hs.config.jwt.jwt_audiences
def validate_login(self, login_submission: JsonDict) -> str: def validate_login(self, login_submission: JsonDict) -> Tuple[str, Optional[str]]:
""" """
Authenticates the user for the /login API Authenticates the user for the /login API
@ -49,7 +50,8 @@ class JwtHandler:
(including 'type' and other relevant fields) (including 'type' and other relevant fields)
Returns: Returns:
The user ID that is logging in. A tuple of (user_id, display_name) of the user that is logging in.
If the JWT does not contain a display name, the second element of the tuple will be None.
Raises: Raises:
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
@ -109,4 +111,10 @@ class JwtHandler:
if user is None: if user is None:
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
return UserID(user, self.hs.hostname).to_string() default_display_name = None
if self.jwt_display_name_claim:
display_name_claim = claims.get(self.jwt_display_name_claim)
if display_name_claim is not None:
default_display_name = display_name_claim
return UserID(user, self.hs.hostname).to_string(), default_display_name

View file

@ -363,6 +363,7 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict, login_submission: JsonDict,
callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None, callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None,
create_non_existent_users: bool = False, create_non_existent_users: bool = False,
default_display_name: Optional[str] = None,
ratelimit: bool = True, ratelimit: bool = True,
auth_provider_id: Optional[str] = None, auth_provider_id: Optional[str] = None,
should_issue_refresh_token: bool = False, should_issue_refresh_token: bool = False,
@ -410,7 +411,8 @@ class LoginRestServlet(RestServlet):
canonical_uid = await self.auth_handler.check_user_exists(user_id) canonical_uid = await self.auth_handler.check_user_exists(user_id)
if not canonical_uid: if not canonical_uid:
canonical_uid = await self.registration_handler.register_user( canonical_uid = await self.registration_handler.register_user(
localpart=UserID.from_string(user_id).localpart localpart=UserID.from_string(user_id).localpart,
default_display_name=default_display_name,
) )
user_id = canonical_uid user_id = canonical_uid
@ -546,11 +548,14 @@ class LoginRestServlet(RestServlet):
Returns: Returns:
The body of the JSON response. The body of the JSON response.
""" """
user_id = self.hs.get_jwt_handler().validate_login(login_submission) user_id, default_display_name = self.hs.get_jwt_handler().validate_login(
login_submission
)
return await self._complete_login( return await self._complete_login(
user_id, user_id,
login_submission, login_submission,
create_non_existent_users=True, create_non_existent_users=True,
default_display_name=default_display_name,
should_issue_refresh_token=should_issue_refresh_token, should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info, request_info=request_info,
) )

View file

@ -1047,6 +1047,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
servlets = [ servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource, synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets, login.register_servlets,
profile.register_servlets,
] ]
jwt_secret = "secret" jwt_secret = "secret"
@ -1202,6 +1203,30 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test") self.assertEqual(channel.json_body["user_id"], "@frog:test")
@override_config(
{"jwt_config": {**base_config, "display_name_claim": "display_name"}}
)
def test_login_custom_display_name(self) -> None:
"""Test setting a custom display name."""
localpart = "pinkie"
user_id = f"@{localpart}:test"
display_name = "Pinkie Pie"
# Perform the login, specifying a custom display name.
channel = self.jwt_login({"sub": localpart, "display_name": display_name})
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], user_id)
# Fetch the user's display name and check that it was set correctly.
access_token = channel.json_body["access_token"]
channel = self.make_request(
"GET",
f"/_matrix/client/v3/profile/{user_id}/displayname",
access_token=access_token,
)
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["displayname"], display_name)
def test_login_no_token(self) -> None: def test_login_no_token(self) -> None:
params = {"type": "org.matrix.login.jwt"} params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)