mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-25 19:15:51 +03:00
parent
3af0672350
commit
a8eceb01e5
4 changed files with 27 additions and 10 deletions
1
changelog.d/8920.bugfix
Normal file
1
changelog.d/8920.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix login API to not ratelimit application services that have ratelimiting disabled.
|
|
@ -31,7 +31,9 @@ from synapse.api.errors import (
|
|||
MissingClientTokenError,
|
||||
)
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.events import EventBase
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging import opentracing as opentracing
|
||||
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||
from synapse.types import StateMap, UserID
|
||||
|
@ -474,7 +476,7 @@ class Auth:
|
|||
now = self.hs.get_clock().time_msec()
|
||||
return now < expiry
|
||||
|
||||
def get_appservice_by_req(self, request):
|
||||
def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
|
||||
token = self.get_access_token_from_request(request)
|
||||
service = self.store.get_app_service_by_token(token)
|
||||
if not service:
|
||||
|
|
|
@ -22,6 +22,7 @@ import urllib.parse
|
|||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
|
@ -861,7 +862,7 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
async def validate_login(
|
||||
self, login_submission: Dict[str, Any], ratelimit: bool = False,
|
||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
|
||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
||||
"""Authenticates the user for the /login API
|
||||
|
||||
Also used by the user-interactive auth flow to validate auth types which don't
|
||||
|
@ -1004,7 +1005,7 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
async def _validate_userid_login(
|
||||
self, username: str, login_submission: Dict[str, Any],
|
||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
|
||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
||||
"""Helper for validate_login
|
||||
|
||||
Handles login, once we've mapped 3pids onto userids
|
||||
|
@ -1082,7 +1083,7 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
async def check_password_provider_3pid(
|
||||
self, medium: str, address: str, password: str
|
||||
) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]:
|
||||
) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
||||
"""Check if a password provider is able to validate a thirdparty login
|
||||
|
||||
Args:
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Awaitable, Callable, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
|
||||
|
||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
|
@ -30,6 +30,9 @@ from synapse.rest.client.v2_alpha._base import client_patterns
|
|||
from synapse.rest.well_known import WellKnownBuilder
|
||||
from synapse.types import JsonDict, UserID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -42,7 +45,7 @@ class LoginRestServlet(RestServlet):
|
|||
JWT_TYPE_DEPRECATED = "m.login.jwt"
|
||||
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
|
||||
|
@ -105,22 +108,27 @@ class LoginRestServlet(RestServlet):
|
|||
return 200, {"flows": flows}
|
||||
|
||||
async def on_POST(self, request: SynapseRequest):
|
||||
self._address_ratelimiter.ratelimit(request.getClientIP())
|
||||
|
||||
login_submission = parse_json_object_from_request(request)
|
||||
|
||||
try:
|
||||
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
|
||||
appservice = self.auth.get_appservice_by_req(request)
|
||||
|
||||
if appservice.is_rate_limited():
|
||||
self._address_ratelimiter.ratelimit(request.getClientIP())
|
||||
|
||||
result = await self._do_appservice_login(login_submission, appservice)
|
||||
elif self.jwt_enabled and (
|
||||
login_submission["type"] == LoginRestServlet.JWT_TYPE
|
||||
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
|
||||
):
|
||||
self._address_ratelimiter.ratelimit(request.getClientIP())
|
||||
result = await self._do_jwt_login(login_submission)
|
||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||
self._address_ratelimiter.ratelimit(request.getClientIP())
|
||||
result = await self._do_token_login(login_submission)
|
||||
else:
|
||||
self._address_ratelimiter.ratelimit(request.getClientIP())
|
||||
result = await self._do_other_login(login_submission)
|
||||
except KeyError:
|
||||
raise SynapseError(400, "Missing JSON keys.")
|
||||
|
@ -159,7 +167,9 @@ class LoginRestServlet(RestServlet):
|
|||
if not appservice.is_interested_in_user(qualified_user_id):
|
||||
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
|
||||
|
||||
return await self._complete_login(qualified_user_id, login_submission)
|
||||
return await self._complete_login(
|
||||
qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
|
||||
)
|
||||
|
||||
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
||||
"""Handle non-token/saml/jwt logins
|
||||
|
@ -194,6 +204,7 @@ class LoginRestServlet(RestServlet):
|
|||
login_submission: JsonDict,
|
||||
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
|
||||
create_non_existent_users: bool = False,
|
||||
ratelimit: bool = True,
|
||||
) -> Dict[str, str]:
|
||||
"""Called when we've successfully authed the user and now need to
|
||||
actually login them in (e.g. create devices). This gets called on
|
||||
|
@ -208,6 +219,7 @@ class LoginRestServlet(RestServlet):
|
|||
callback: Callback function to run after login.
|
||||
create_non_existent_users: Whether to create the user if they don't
|
||||
exist. Defaults to False.
|
||||
ratelimit: Whether to ratelimit the login request.
|
||||
|
||||
Returns:
|
||||
result: Dictionary of account information after successful login.
|
||||
|
@ -216,7 +228,8 @@ class LoginRestServlet(RestServlet):
|
|||
# Before we actually log them in we check if they've already logged in
|
||||
# too often. This happens here rather than before as we don't
|
||||
# necessarily know the user before now.
|
||||
self._account_ratelimiter.ratelimit(user_id.lower())
|
||||
if ratelimit:
|
||||
self._account_ratelimiter.ratelimit(user_id.lower())
|
||||
|
||||
if create_non_existent_users:
|
||||
canonical_uid = await self.auth_handler.check_user_exists(user_id)
|
||||
|
|
Loading…
Reference in a new issue