mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-22 12:44:30 +03:00
23740eaa3d
During the migration the automated script to update the copyright headers accidentally got rid of some of the existing copyright lines. Reinstate them.
355 lines
12 KiB
Python
355 lines
12 KiB
Python
#
|
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
|
#
|
|
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
|
# Copyright (C) 2023 New Vector, Ltd
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as
|
|
# published by the Free Software Foundation, either version 3 of the
|
|
# License, or (at your option) any later version.
|
|
#
|
|
# See the GNU Affero General Public License for more details:
|
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
|
#
|
|
# Originally licensed under the Apache License, Version 2.0:
|
|
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
|
#
|
|
# [This file includes modifications made by New Vector Limited]
|
|
#
|
|
#
|
|
|
|
|
|
import json
|
|
from typing import Any, ContextManager, Dict, List, Optional, Tuple
|
|
from unittest.mock import Mock, patch
|
|
from urllib.parse import parse_qs
|
|
|
|
import attr
|
|
|
|
from twisted.web.http_headers import Headers
|
|
from twisted.web.iweb import IResponse
|
|
|
|
from synapse.server import HomeServer
|
|
from synapse.util import Clock
|
|
from synapse.util.stringutils import random_string
|
|
|
|
from tests.test_utils import FakeResponse
|
|
|
|
|
|
@attr.s(auto_attribs=True, frozen=True, slots=True)
class FakeAuthorizationGrant:
    """An immutable record of one authorization issued by the fake provider.

    Attributes:
        userinfo: Claims describing the user. They are merged into generated ID
            tokens and returned as-is by the userinfo endpoint.
        client_id: The client the grant was issued to; used as the ``aud`` claim
            of generated tokens.
        redirect_uri: The redirect URI supplied when the authorization started.
        scope: The granted scope string, echoed back in the token response.
        nonce: If not None, copied into the ``nonce`` claim of the ID token.
        sid: If not None, copied into the ``sid`` claim of the ID and logout
            tokens.
    """

    userinfo: dict
    client_id: str
    redirect_uri: str
    scope: str
    nonce: Optional[str]
    sid: Optional[str]
|
|
|
|
|
|
class FakeOidcServer:
    """A fake OpenID Connect Provider.

    HTTP requests are dispatched by ``_request`` to one handler per endpoint
    (JWKS, discovery metadata, userinfo and token). ``request`` and every handler
    are exposed as ``Mock`` objects wrapping the real implementation, so tests can
    inspect the calls that were made or temporarily replace a handler.
    """

    # All methods here are mocks, so we can track when they are called, and override
    # their values
    request: Mock
    get_jwks_handler: Mock
    get_metadata_handler: Mock
    get_userinfo_handler: Mock
    post_token_handler: Mock

    # Value of the next session ID handed out by `start_authorization` when it is
    # called with `with_sid=True`.
    sid_counter: int = 0

    def __init__(self, clock: Clock, issuer: str):
        """
        Args:
            clock: Provides the current time for the timestamps put in tokens.
            issuer: The issuer URL. Endpoint URLs are built by appending the
                endpoint path directly to it (no separator is inserted), so it
                must include any trailing ``/``.
        """
        from authlib.jose import ECKey, KeySet

        self._clock = clock
        self.issuer = issuer

        # Wrap each implementation in a Mock so that calls are recorded and the
        # behaviour can be overridden (see `buggy_endpoint`).
        self.request = Mock(side_effect=self._request)
        self.get_jwks_handler = Mock(side_effect=self._get_jwks_handler)
        self.get_metadata_handler = Mock(side_effect=self._get_metadata_handler)
        self.get_userinfo_handler = Mock(side_effect=self._get_userinfo_handler)
        self.post_token_handler = Mock(side_effect=self._post_token_handler)

        # A code -> grant mapping
        self._authorization_grants: Dict[str, FakeAuthorizationGrant] = {}
        # An access token -> grant mapping
        self._sessions: Dict[str, FakeAuthorizationGrant] = {}

        # We generate here an ECDSA key with the P-256 curve (ES256 algorithm) used for
        # signing JWTs. ECDSA keys are really quick to generate compared to RSA.
        self._key = ECKey.generate_key(crv="P-256", is_private=True)
        # The published key set only holds the public half of the signing key.
        self._jwks = KeySet([ECKey.import_key(self._key.as_pem(is_private=False))])

        # Claims merged last into every generated ID token (see `id_token_override`).
        self._id_token_overrides: Dict[str, Any] = {}

    def reset_mocks(self) -> None:
        """Forget the calls recorded so far on ``request`` and on every handler."""
        self.request.reset_mock()
        self.get_jwks_handler.reset_mock()
        self.get_metadata_handler.reset_mock()
        self.get_userinfo_handler.reset_mock()
        self.post_token_handler.reset_mock()

    def patch_homeserver(self, hs: HomeServer) -> ContextManager[Mock]:
        """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``.

        This patch should be used whenever the HS is expected to perform requests to
        the OIDC provider, e.g.::

            fake_oidc_server = self.helper.fake_oidc_server()
            with fake_oidc_server.patch_homeserver(hs):
                self.make_request("GET", "/_matrix/client/r0/login/sso/redirect")
        """
        return patch.object(hs.get_proxied_http_client(), "request", self.request)

    @property
    def authorization_endpoint(self) -> str:
        """URL of the authorization endpoint."""
        return self.issuer + "authorize"

    @property
    def token_endpoint(self) -> str:
        """URL of the token endpoint."""
        return self.issuer + "token"

    @property
    def userinfo_endpoint(self) -> str:
        """URL of the userinfo endpoint."""
        return self.issuer + "userinfo"

    @property
    def metadata_endpoint(self) -> str:
        """URL of the OIDC discovery document."""
        return self.issuer + ".well-known/openid-configuration"

    @property
    def jwks_uri(self) -> str:
        """URL at which the JSON Web Key Set is served."""
        return self.issuer + "jwks"

    def get_metadata(self) -> dict:
        """Build the provider metadata served at ``metadata_endpoint``."""
        return {
            "issuer": self.issuer,
            "authorization_endpoint": self.authorization_endpoint,
            "token_endpoint": self.token_endpoint,
            "jwks_uri": self.jwks_uri,
            "userinfo_endpoint": self.userinfo_endpoint,
            "response_types_supported": ["code"],
            "subject_types_supported": ["public"],
            "id_token_signing_alg_values_supported": ["ES256"],
        }

    def get_jwks(self) -> dict:
        """Return the public key set as a JSON-serialisable dict."""
        return self._jwks.as_dict()

    def get_userinfo(self, access_token: str) -> Optional[dict]:
        """Given an access token, get the userinfo of the associated session.

        Returns:
            The userinfo of the grant the token was issued for, or None if the
            access token is unknown.
        """
        session = self._sessions.get(access_token, None)
        if session is None:
            return None
        return session.userinfo

    def _sign(self, payload: dict) -> str:
        """Serialise ``payload`` as JSON and sign it as a compact ES256 JWS.

        The protected header carries the ``kid`` of the first key of the published
        key set.
        """
        from authlib.jose import JsonWebSignature

        jws = JsonWebSignature()
        kid = self.get_jwks()["keys"][0]["kid"]
        protected = {"alg": "ES256", "kid": kid}
        json_payload = json.dumps(payload)
        return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8")

    def generate_id_token(self, grant: FakeAuthorizationGrant) -> str:
        """Generate a signed ID token for the given grant.

        The claims are, in increasing order of precedence: the grant's userinfo,
        the standard claims (``iss``, ``aud``, ``iat``, ``nbf``, ``exp``, plus
        ``nonce`` and ``sid`` when the grant has them), and finally any override
        installed through ``id_token_override``.
        """
        now = int(self._clock.time())
        id_token: Dict[str, Any] = {
            **grant.userinfo,
            "iss": self.issuer,
            "aud": grant.client_id,
            "iat": now,
            "nbf": now,
            "exp": now + 600,
        }

        if grant.nonce is not None:
            id_token["nonce"] = grant.nonce

        if grant.sid is not None:
            id_token["sid"] = grant.sid

        # Applied last so that a test can replace or corrupt any claim.
        id_token.update(self._id_token_overrides)

        return self._sign(id_token)

    def generate_logout_token(self, grant: FakeAuthorizationGrant) -> str:
        """Generate a signed back-channel logout token for the given grant.

        The ``sid`` claim is included when the grant has a session ID, and the
        ``sub`` claim when the grant's userinfo has one.
        """
        now = int(self._clock.time())
        logout_token: Dict[str, Any] = {
            "iss": self.issuer,
            "aud": grant.client_id,
            "iat": now,
            "jti": random_string(10),
            "events": {
                "http://schemas.openid.net/event/backchannel-logout": {},
            },
        }

        if grant.sid is not None:
            logout_token["sid"] = grant.sid

        if "sub" in grant.userinfo:
            logout_token["sub"] = grant.userinfo["sub"]

        return self._sign(logout_token)

    def id_token_override(self, overrides: dict) -> ContextManager[dict]:
        """Temporarily patch the ID token generated by the token endpoint."""
        return patch.object(self, "_id_token_overrides", overrides)

    def start_authorization(
        self,
        client_id: str,
        scope: str,
        redirect_uri: str,
        userinfo: dict,
        nonce: Optional[str] = None,
        with_sid: bool = False,
    ) -> Tuple[str, FakeAuthorizationGrant]:
        """Start an authorization request, and get back the code to use on the authorization endpoint.

        Args:
            client_id: The client the grant is issued to.
            scope: The granted scope, as a space-delimited string.
            redirect_uri: The redirect URI recorded on the grant.
            userinfo: The claims describing the user.
            nonce: An optional nonce to include in the ID token.
            with_sid: If True, allocate a new session ID for this grant.

        Returns:
            A tuple of the authorization code and the grant registered under it.
        """
        code = random_string(10)
        sid = None
        if with_sid:
            sid = str(self.sid_counter)
            self.sid_counter += 1

        grant = FakeAuthorizationGrant(
            userinfo=userinfo,
            scope=scope,
            redirect_uri=redirect_uri,
            nonce=nonce,
            client_id=client_id,
            sid=sid,
        )
        self._authorization_grants[code] = grant

        return code, grant

    def exchange_code(self, code: str) -> Optional[Dict[str, Any]]:
        """Exchange an authorization code for a token response.

        The code is consumed: exchanging the same code a second time fails.

        Returns:
            The token response, which includes an ``id_token`` when the ``openid``
            scope was granted, or None if the code is unknown.
        """
        grant = self._authorization_grants.pop(code, None)
        if grant is None:
            return None

        access_token = random_string(10)
        self._sessions[access_token] = grant

        token: Dict[str, Any] = {
            "token_type": "Bearer",
            "access_token": access_token,
            "expires_in": 3600,
            "scope": grant.scope,
        }

        # The scope is a space-delimited list of tokens, so look for the exact
        # `openid` token: a plain substring test would also match a scope such as
        # `notopenid`.
        if "openid" in grant.scope.split(" "):
            token["id_token"] = self.generate_id_token(grant)

        return token

    def buggy_endpoint(
        self,
        *,
        jwks: bool = False,
        metadata: bool = False,
        token: bool = False,
        userinfo: bool = False,
    ) -> ContextManager[Dict[str, Mock]]:
        """A context which makes a set of endpoints return a 500 error.

        At least one flag should be set: ``patch.multiple`` rejects an empty set
        of patches.

        Args:
            jwks: If True, makes the JWKS endpoint return a 500 error.
            metadata: If True, makes the OIDC Discovery endpoint return a 500 error.
            token: If True, makes the token endpoint return a 500 error.
            userinfo: If True, makes the userinfo endpoint return a 500 error.
        """
        buggy = FakeResponse(code=500, body=b"Internal server error")

        # Attribute name -> replacement handler which always answers with `buggy`.
        patches: Dict[str, Mock] = {}
        if jwks:
            patches["get_jwks_handler"] = Mock(return_value=buggy)
        if metadata:
            patches["get_metadata_handler"] = Mock(return_value=buggy)
        if token:
            patches["post_token_handler"] = Mock(return_value=buggy)
        if userinfo:
            patches["get_userinfo_handler"] = Mock(return_value=buggy)

        return patch.multiple(self, **patches)

    async def _request(
        self,
        method: str,
        uri: str,
        data: Optional[bytes] = None,
        headers: Optional[Headers] = None,
    ) -> IResponse:
        """The override of the SimpleHttpClient#request() method

        Routes the request to the handler matching its method and URI. A POST
        without an url-encoded body gets a 400 response, and anything which
        matches no endpoint gets a 404 response.
        """
        access_token: Optional[str] = None

        if headers is None:
            headers = Headers()

        # Try to find the access token in the headers if any
        auth_headers = headers.getRawHeaders(b"Authorization")
        if auth_headers:
            parts = auth_headers[0].split(b" ")
            if parts[0] == b"Bearer" and len(parts) == 2:
                access_token = parts[1].decode("ascii")

        if method == "POST":
            # If the method is POST, assume it has an url-encoded body
            if data is None or headers.getRawHeaders(b"Content-Type") != [
                b"application/x-www-form-urlencoded"
            ]:
                return FakeResponse.json(code=400, payload={"error": "invalid_request"})

            params = parse_qs(data.decode("utf-8"))

            if uri == self.token_endpoint:
                # Even though this endpoint should be protected, this does not check
                # for client authentication. We're not checking it for simplicity,
                # and because client authentication is tested in other standalone tests.
                return self.post_token_handler(params)

        elif method == "GET":
            if uri == self.jwks_uri:
                return self.get_jwks_handler()
            elif uri == self.metadata_endpoint:
                return self.get_metadata_handler()
            elif uri == self.userinfo_endpoint:
                return self.get_userinfo_handler(access_token=access_token)

        return FakeResponse(code=404, body=b"404 not found")

    # Request handlers
    def _get_jwks_handler(self) -> IResponse:
        """Handles requests to the JWKS URI."""
        return FakeResponse.json(payload=self.get_jwks())

    def _get_metadata_handler(self) -> IResponse:
        """Handles requests to the OIDC well-known document."""
        return FakeResponse.json(payload=self.get_metadata())

    def _get_userinfo_handler(self, access_token: Optional[str]) -> IResponse:
        """Handles requests to the userinfo endpoint.

        Responds with a 401 if no access token was given or if it is unknown.
        """
        if access_token is None:
            return FakeResponse(code=401)
        user_info = self.get_userinfo(access_token)
        if user_info is None:
            return FakeResponse(code=401)

        return FakeResponse.json(payload=user_info)

    def _post_token_handler(self, params: Dict[str, List[str]]) -> IResponse:
        """Handles requests to the token endpoint.

        Responds with an ``invalid_request`` error unless exactly one ``code``
        parameter was sent, and with ``invalid_grant`` if that code is unknown.
        """
        code = params.get("code", [])

        if len(code) != 1:
            return FakeResponse.json(code=400, payload={"error": "invalid_request"})

        grant = self.exchange_code(code=code[0])
        if grant is None:
            return FakeResponse.json(code=400, payload={"error": "invalid_grant"})

        return FakeResponse.json(payload=grant)
|