Honour AS ratelimit settings for /login requests (#8920)

Fixes #8846.
This commit is contained in:
Erik Johnston 2020-12-11 16:33:31 +00:00 committed by GitHub
parent 3af0672350
commit a8eceb01e5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 27 additions and 10 deletions

1
changelog.d/8920.bugfix Normal file
View file

@ -0,0 +1 @@
Fix login API to not ratelimit application services that have ratelimiting disabled.

View file

@ -31,7 +31,9 @@ from synapse.api.errors import (
MissingClientTokenError, MissingClientTokenError,
) )
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.appservice import ApplicationService
from synapse.events import EventBase from synapse.events import EventBase
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing as opentracing from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID from synapse.types import StateMap, UserID
@ -474,7 +476,7 @@ class Auth:
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
return now < expiry return now < expiry
def get_appservice_by_req(self, request): def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
token = self.get_access_token_from_request(request) token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token) service = self.store.get_app_service_by_token(token)
if not service: if not service:

View file

@ -22,6 +22,7 @@ import urllib.parse
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Awaitable,
Callable, Callable,
Dict, Dict,
Iterable, Iterable,
@ -861,7 +862,7 @@ class AuthHandler(BaseHandler):
async def validate_login( async def validate_login(
self, login_submission: Dict[str, Any], ratelimit: bool = False, self, login_submission: Dict[str, Any], ratelimit: bool = False,
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Authenticates the user for the /login API """Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate auth types which don't Also used by the user-interactive auth flow to validate auth types which don't
@ -1004,7 +1005,7 @@ class AuthHandler(BaseHandler):
async def _validate_userid_login( async def _validate_userid_login(
self, username: str, login_submission: Dict[str, Any], self, username: str, login_submission: Dict[str, Any],
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Helper for validate_login """Helper for validate_login
Handles login, once we've mapped 3pids onto userids Handles login, once we've mapped 3pids onto userids
@ -1082,7 +1083,7 @@ class AuthHandler(BaseHandler):
async def check_password_provider_3pid( async def check_password_provider_3pid(
self, medium: str, address: str, password: str self, medium: str, address: str, password: str
) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]: ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Check if a password provider is able to validate a thirdparty login """Check if a password provider is able to validate a thirdparty login
Args: Args:

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Awaitable, Callable, Dict, Optional from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
@ -30,6 +30,9 @@ from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -42,7 +45,7 @@ class LoginRestServlet(RestServlet):
JWT_TYPE_DEPRECATED = "m.login.jwt" JWT_TYPE_DEPRECATED = "m.login.jwt"
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
@ -105,22 +108,27 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows} return 200, {"flows": flows}
async def on_POST(self, request: SynapseRequest): async def on_POST(self, request: SynapseRequest):
self._address_ratelimiter.ratelimit(request.getClientIP())
login_submission = parse_json_object_from_request(request) login_submission = parse_json_object_from_request(request)
try: try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
appservice = self.auth.get_appservice_by_req(request) appservice = self.auth.get_appservice_by_req(request)
if appservice.is_rate_limited():
self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_appservice_login(login_submission, appservice) result = await self._do_appservice_login(login_submission, appservice)
elif self.jwt_enabled and ( elif self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
): ):
self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_jwt_login(login_submission) result = await self._do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_token_login(login_submission) result = await self._do_token_login(login_submission)
else: else:
self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_other_login(login_submission) result = await self._do_other_login(login_submission)
except KeyError: except KeyError:
raise SynapseError(400, "Missing JSON keys.") raise SynapseError(400, "Missing JSON keys.")
@ -159,7 +167,9 @@ class LoginRestServlet(RestServlet):
if not appservice.is_interested_in_user(qualified_user_id): if not appservice.is_interested_in_user(qualified_user_id):
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
return await self._complete_login(qualified_user_id, login_submission) return await self._complete_login(
qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
)
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]: async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
"""Handle non-token/saml/jwt logins """Handle non-token/saml/jwt logins
@ -194,6 +204,7 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict, login_submission: JsonDict,
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None, callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
create_non_existent_users: bool = False, create_non_existent_users: bool = False,
ratelimit: bool = True,
) -> Dict[str, str]: ) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to """Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on actually login them in (e.g. create devices). This gets called on
@ -208,6 +219,7 @@ class LoginRestServlet(RestServlet):
callback: Callback function to run after login. callback: Callback function to run after login.
create_non_existent_users: Whether to create the user if they don't create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False. exist. Defaults to False.
ratelimit: Whether to ratelimit the login request.
Returns: Returns:
result: Dictionary of account information after successful login. result: Dictionary of account information after successful login.
@ -216,7 +228,8 @@ class LoginRestServlet(RestServlet):
# Before we actually log them in we check if they've already logged in # Before we actually log them in we check if they've already logged in
# too often. This happens here rather than before as we don't # too often. This happens here rather than before as we don't
# necessarily know the user before now. # necessarily know the user before now.
self._account_ratelimiter.ratelimit(user_id.lower()) if ratelimit:
self._account_ratelimiter.ratelimit(user_id.lower())
if create_non_existent_users: if create_non_existent_users:
canonical_uid = await self.auth_handler.check_user_exists(user_id) canonical_uid = await self.auth_handler.check_user_exists(user_id)