mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-29 15:39:00 +03:00
parent
3af0672350
commit
a8eceb01e5
4 changed files with 27 additions and 10 deletions
1
changelog.d/8920.bugfix
Normal file
1
changelog.d/8920.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix login API to not ratelimit application services that have ratelimiting disabled.
|
|
@ -31,7 +31,9 @@ from synapse.api.errors import (
|
||||||
MissingClientTokenError,
|
MissingClientTokenError,
|
||||||
)
|
)
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging import opentracing as opentracing
|
from synapse.logging import opentracing as opentracing
|
||||||
from synapse.storage.databases.main.registration import TokenLookupResult
|
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||||
from synapse.types import StateMap, UserID
|
from synapse.types import StateMap, UserID
|
||||||
|
@ -474,7 +476,7 @@ class Auth:
|
||||||
now = self.hs.get_clock().time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
return now < expiry
|
return now < expiry
|
||||||
|
|
||||||
def get_appservice_by_req(self, request):
|
def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
|
||||||
token = self.get_access_token_from_request(request)
|
token = self.get_access_token_from_request(request)
|
||||||
service = self.store.get_app_service_by_token(token)
|
service = self.store.get_app_service_by_token(token)
|
||||||
if not service:
|
if not service:
|
||||||
|
|
|
@ -22,6 +22,7 @@ import urllib.parse
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
|
@ -861,7 +862,7 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
async def validate_login(
|
async def validate_login(
|
||||||
self, login_submission: Dict[str, Any], ratelimit: bool = False,
|
self, login_submission: Dict[str, Any], ratelimit: bool = False,
|
||||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
|
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
||||||
"""Authenticates the user for the /login API
|
"""Authenticates the user for the /login API
|
||||||
|
|
||||||
Also used by the user-interactive auth flow to validate auth types which don't
|
Also used by the user-interactive auth flow to validate auth types which don't
|
||||||
|
@ -1004,7 +1005,7 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
async def _validate_userid_login(
|
async def _validate_userid_login(
|
||||||
self, username: str, login_submission: Dict[str, Any],
|
self, username: str, login_submission: Dict[str, Any],
|
||||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
|
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
||||||
"""Helper for validate_login
|
"""Helper for validate_login
|
||||||
|
|
||||||
Handles login, once we've mapped 3pids onto userids
|
Handles login, once we've mapped 3pids onto userids
|
||||||
|
@ -1082,7 +1083,7 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
async def check_password_provider_3pid(
|
async def check_password_provider_3pid(
|
||||||
self, medium: str, address: str, password: str
|
self, medium: str, address: str, password: str
|
||||||
) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]:
|
) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
||||||
"""Check if a password provider is able to validate a thirdparty login
|
"""Check if a password provider is able to validate a thirdparty login
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Awaitable, Callable, Dict, Optional
|
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
|
||||||
|
|
||||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
|
@ -30,6 +30,9 @@ from synapse.rest.client.v2_alpha._base import client_patterns
|
||||||
from synapse.rest.well_known import WellKnownBuilder
|
from synapse.rest.well_known import WellKnownBuilder
|
||||||
from synapse.types import JsonDict, UserID
|
from synapse.types import JsonDict, UserID
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,7 +45,7 @@ class LoginRestServlet(RestServlet):
|
||||||
JWT_TYPE_DEPRECATED = "m.login.jwt"
|
JWT_TYPE_DEPRECATED = "m.login.jwt"
|
||||||
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
|
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
|
@ -105,22 +108,27 @@ class LoginRestServlet(RestServlet):
|
||||||
return 200, {"flows": flows}
|
return 200, {"flows": flows}
|
||||||
|
|
||||||
async def on_POST(self, request: SynapseRequest):
|
async def on_POST(self, request: SynapseRequest):
|
||||||
self._address_ratelimiter.ratelimit(request.getClientIP())
|
|
||||||
|
|
||||||
login_submission = parse_json_object_from_request(request)
|
login_submission = parse_json_object_from_request(request)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
|
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
|
||||||
appservice = self.auth.get_appservice_by_req(request)
|
appservice = self.auth.get_appservice_by_req(request)
|
||||||
|
|
||||||
|
if appservice.is_rate_limited():
|
||||||
|
self._address_ratelimiter.ratelimit(request.getClientIP())
|
||||||
|
|
||||||
result = await self._do_appservice_login(login_submission, appservice)
|
result = await self._do_appservice_login(login_submission, appservice)
|
||||||
elif self.jwt_enabled and (
|
elif self.jwt_enabled and (
|
||||||
login_submission["type"] == LoginRestServlet.JWT_TYPE
|
login_submission["type"] == LoginRestServlet.JWT_TYPE
|
||||||
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
|
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
|
||||||
):
|
):
|
||||||
|
self._address_ratelimiter.ratelimit(request.getClientIP())
|
||||||
result = await self._do_jwt_login(login_submission)
|
result = await self._do_jwt_login(login_submission)
|
||||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||||
|
self._address_ratelimiter.ratelimit(request.getClientIP())
|
||||||
result = await self._do_token_login(login_submission)
|
result = await self._do_token_login(login_submission)
|
||||||
else:
|
else:
|
||||||
|
self._address_ratelimiter.ratelimit(request.getClientIP())
|
||||||
result = await self._do_other_login(login_submission)
|
result = await self._do_other_login(login_submission)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise SynapseError(400, "Missing JSON keys.")
|
raise SynapseError(400, "Missing JSON keys.")
|
||||||
|
@ -159,7 +167,9 @@ class LoginRestServlet(RestServlet):
|
||||||
if not appservice.is_interested_in_user(qualified_user_id):
|
if not appservice.is_interested_in_user(qualified_user_id):
|
||||||
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
|
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
return await self._complete_login(qualified_user_id, login_submission)
|
return await self._complete_login(
|
||||||
|
qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
|
||||||
|
)
|
||||||
|
|
||||||
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
||||||
"""Handle non-token/saml/jwt logins
|
"""Handle non-token/saml/jwt logins
|
||||||
|
@ -194,6 +204,7 @@ class LoginRestServlet(RestServlet):
|
||||||
login_submission: JsonDict,
|
login_submission: JsonDict,
|
||||||
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
|
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
|
||||||
create_non_existent_users: bool = False,
|
create_non_existent_users: bool = False,
|
||||||
|
ratelimit: bool = True,
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, str]:
|
||||||
"""Called when we've successfully authed the user and now need to
|
"""Called when we've successfully authed the user and now need to
|
||||||
actually login them in (e.g. create devices). This gets called on
|
actually login them in (e.g. create devices). This gets called on
|
||||||
|
@ -208,6 +219,7 @@ class LoginRestServlet(RestServlet):
|
||||||
callback: Callback function to run after login.
|
callback: Callback function to run after login.
|
||||||
create_non_existent_users: Whether to create the user if they don't
|
create_non_existent_users: Whether to create the user if they don't
|
||||||
exist. Defaults to False.
|
exist. Defaults to False.
|
||||||
|
ratelimit: Whether to ratelimit the login request.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
result: Dictionary of account information after successful login.
|
result: Dictionary of account information after successful login.
|
||||||
|
@ -216,7 +228,8 @@ class LoginRestServlet(RestServlet):
|
||||||
# Before we actually log them in we check if they've already logged in
|
# Before we actually log them in we check if they've already logged in
|
||||||
# too often. This happens here rather than before as we don't
|
# too often. This happens here rather than before as we don't
|
||||||
# necessarily know the user before now.
|
# necessarily know the user before now.
|
||||||
self._account_ratelimiter.ratelimit(user_id.lower())
|
if ratelimit:
|
||||||
|
self._account_ratelimiter.ratelimit(user_id.lower())
|
||||||
|
|
||||||
if create_non_existent_users:
|
if create_non_existent_users:
|
||||||
canonical_uid = await self.auth_handler.check_user_exists(user_id)
|
canonical_uid = await self.auth_handler.check_user_exists(user_id)
|
||||||
|
|
Loading…
Reference in a new issue