This commit is contained in:
Andrew Morgan 2020-06-01 17:04:52 +01:00
parent f6203a60e0
commit 27e4157727
14 changed files with 169 additions and 158 deletions

View file

@ -16,6 +16,10 @@ from collections import OrderedDict
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
from synapse.api.errors import LimitExceededError from synapse.api.errors import LimitExceededError
from synapse.util import Clock
import logging
logger = logging.getLogger(__name__)
class Ratelimiter(object): class Ratelimiter(object):
@ -23,24 +27,30 @@ class Ratelimiter(object):
Ratelimit actions marked by arbitrary keys. Ratelimit actions marked by arbitrary keys.
Args: Args:
clock: A homeserver clock, for retrieving the current time
rate_hz: The long term number of actions that can be performed in a second. rate_hz: The long term number of actions that can be performed in a second.
burst_count: How many actions that can be performed before being limited. burst_count: How many actions that can be performed before being limited.
""" """
def __init__(self, rate_hz: float, burst_count: int): def __init__(self, clock: Clock, rate_hz: float, burst_count: int):
self.clock = clock
self.rate_hz = rate_hz
self.burst_count = burst_count
# A ordered dictionary keeping track of actions, when they were last # A ordered dictionary keeping track of actions, when they were last
# performed and how often. Each entry is a mapping from a key of arbitrary type # performed and how often. Each entry is a mapping from a key of arbitrary type
# to a tuple representing: # to a tuple representing:
# * How many times an action has occurred since a point in time # * How many times an action has occurred since a point in time
# * That point in time # * The point in time
self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int]] # * The rate_hz of this particular entry. This can vary per-request
self.rate_hz = rate_hz self.actions = (
self.burst_count = burst_count OrderedDict()
) # type: OrderedDict[Any, Tuple[float, int, Optional[float]]]
def can_do_action( def can_do_action(
self, self,
key: Any, key: Any,
time_now_s: int, time_now_s: Optional[int] = None,
update: bool = True, update: bool = True,
rate_hz: Optional[float] = None, rate_hz: Optional[float] = None,
burst_count: Optional[int] = None, burst_count: Optional[int] = None,
@ -50,7 +60,8 @@ class Ratelimiter(object):
Args: Args:
key: The key we should use when rate limiting. Can be a user ID key: The key we should use when rate limiting. Can be a user ID
(when sending events), an IP address, etc. (when sending events), an IP address, etc.
time_now_s: The time now time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Pretty much only used for tests.
update: Whether to count this check as performing the action update: Whether to count this check as performing the action
rate_hz: The long term number of actions that can be performed in a second. rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set. Overrides the value set during instantiation if set.
@ -64,14 +75,15 @@ class Ratelimiter(object):
-1 if a rate_hz has not been defined for this Ratelimiter -1 if a rate_hz has not been defined for this Ratelimiter
""" """
# Override default values if set # Override default values if set
rate_hz = rate_hz or self.rate_hz time_now_s = time_now_s if time_now_s is not None else self.clock.time()
burst_count = burst_count or self.burst_count rate_hz = rate_hz if rate_hz is not None else self.rate_hz
burst_count = burst_count if burst_count is not None else self.burst_count
# Remove any expired entries # Remove any expired entries
self._prune_message_counts(time_now_s, rate_hz) self._prune_message_counts(time_now_s)
# Check if there is an existing count entry for this key # Check if there is an existing count entry for this key
action_count, time_start, = self.actions.get(key, (0.0, time_now_s)) action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, None))
# Check whether performing another action is allowed # Check whether performing another action is allowed
time_delta = time_now_s - time_start time_delta = time_now_s - time_start
@ -90,7 +102,10 @@ class Ratelimiter(object):
action_count += 1.0 action_count += 1.0
if update: if update:
self.actions[key] = (action_count, time_start) self.actions[key] = (action_count, time_start, rate_hz)
logger.info("rate and burst: %s %s. performed_count: %s, allowed: %s", rate_hz,
burst_count, performed_count, allowed)
# Figure out the time when an action can be performed again # Figure out the time when an action can be performed again
if self.rate_hz > 0: if self.rate_hz > 0:
@ -105,18 +120,17 @@ class Ratelimiter(object):
return allowed, time_allowed return allowed, time_allowed
def _prune_message_counts(self, time_now_s: int, rate_hz: float): def _prune_message_counts(self, time_now_s: int):
"""Remove message count entries that have not exceeded their defined """Remove message count entries that have not exceeded their defined
rate_hz limit rate_hz limit
Args: Args:
time_now_s: The current time time_now_s: The current time
rate_hz: The long term number of actions that can be performed in a second.
""" """
# We create a copy of the key list here as the dictionary is modified during # We create a copy of the key list here as the dictionary is modified during
# the loop # the loop
for key in list(self.actions.keys()): for key in list(self.actions.keys()):
action_count, time_start = self.actions[key] action_count, time_start, rate_hz = self.actions[key]
# Rate limit = "seconds since we started limiting this action" * rate_hz # Rate limit = "seconds since we started limiting this action" * rate_hz
# If this limit has not been exceeded, wipe our record of this action # If this limit has not been exceeded, wipe our record of this action
@ -129,7 +143,7 @@ class Ratelimiter(object):
def ratelimit( def ratelimit(
self, self,
key: Any, key: Any,
time_now_s: int, time_now_s: Optional[int] = None,
update: bool = True, update: bool = True,
rate_hz: Optional[float] = None, rate_hz: Optional[float] = None,
burst_count: Optional[int] = None, burst_count: Optional[int] = None,
@ -138,7 +152,8 @@ class Ratelimiter(object):
Args: Args:
key: An arbitrary key used to classify an action key: An arbitrary key used to classify an action
time_now_s: The current time time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Pretty much only used for tests.
update: Whether to count this check as performing the action update: Whether to count this check as performing the action
rate_hz: The long term number of actions that can be performed in a second. rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set. Overrides the value set during instantiation if set.
@ -150,8 +165,9 @@ class Ratelimiter(object):
milliseconds until the action can be performed again milliseconds until the action can be performed again
""" """
# Override default values if set # Override default values if set
rate_hz = rate_hz or self.rate_hz time_now_s = time_now_s if time_now_s is not None else self.clock.time()
burst_count = burst_count or self.burst_count rate_hz = rate_hz if rate_hz is not None else self.rate_hz
burst_count = burst_count if burst_count is not None else self.burst_count
allowed, time_allowed = self.can_do_action( allowed, time_allowed = self.can_do_action(
key, time_now_s, update=update, rate_hz=rate_hz, burst_count=burst_count key, time_now_s, update=update, rate_hz=rate_hz, burst_count=burst_count

View file

@ -20,6 +20,7 @@ from twisted.internet import defer
import synapse.types import synapse.types
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID from synapse.types import UserID
from synapse.api.ratelimiting import Ratelimiter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -46,11 +47,20 @@ class BaseHandler(object):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.hs = hs self.hs = hs
self.request_ratelimiter = hs.get_request_ratelimiter() # The rate_hz and burst_count are overridden on a per-user basis
self.request_ratelimiter = Ratelimiter(clock=self.clock, rate_hz=0, burst_count=0)
self._rc_message = self.hs.config.rc_message self._rc_message = self.hs.config.rc_message
# If special admin redaction ratelimiting is disabled, this will be None # Check whether ratelimiting room admin message redaction is enabled
self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter() # by the presence of rate limits in the config
if self.hs.config.rc_admin_redaction:
self.admin_redaction_ratelimiter = Ratelimiter(
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
)
else:
self.admin_redaction_ratelimiter = None
self.server_name = hs.hostname self.server_name = hs.hostname

View file

@ -110,6 +110,7 @@ class AuthHandler(BaseHandler):
# as per `rc_login.failed_attempts`. # as per `rc_login.failed_attempts`.
# XXX: Should this be hs.get_login_failed_attempts_ratelimiter? # XXX: Should this be hs.get_login_failed_attempts_ratelimiter?
self._failed_uia_attempts_ratelimiter = Ratelimiter( self._failed_uia_attempts_ratelimiter = Ratelimiter(
clock=self.clock,
rate_hz=self.hs.config.rc_login_failed_attempts.per_second, rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count, burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
) )
@ -200,9 +201,7 @@ class AuthHandler(BaseHandler):
user_id = requester.user.to_string() user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts # Check if we should be ratelimited due to too many previous failed attempts
self._failed_uia_attempts_ratelimiter.ratelimit( self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
user_id, time_now_s=self._clock.time(), update=False,
)
# build a list of supported flows # build a list of supported flows
flows = [[login_type] for login_type in self._supported_ui_auth_types] flows = [[login_type] for login_type in self._supported_ui_auth_types]
@ -212,10 +211,8 @@ class AuthHandler(BaseHandler):
flows, request, request_body, clientip, description flows, request, request_body, clientip, description
) )
except LoginError: except LoginError:
# Update the ratelimite to say we failed (`can_do_action` doesn't raise). # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
self._failed_uia_attempts_ratelimiter.can_do_action( self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
user_id, time_now_s=self._clock.time(), update=True,
)
raise raise
# find the completed login type # find the completed login type

View file

@ -427,9 +427,7 @@ class RegistrationHandler(BaseHandler):
time_now = self.clock.time() time_now = self.clock.time()
self.ratelimiter.ratelimit( self.ratelimiter.ratelimit(address)
address, time_now_s=time_now,
)
def register_with_store( def register_with_store(
self, self,

View file

@ -27,6 +27,7 @@ from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from synapse.api.ratelimiting import Ratelimiter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -86,10 +87,28 @@ class LoginRestServlet(RestServlet):
self.auth_handler = self.hs.get_auth_handler() self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self._clock = hs.get_clock()
self._well_known_builder = WellKnownBuilder(hs) self._well_known_builder = WellKnownBuilder(hs)
self._account_ratelimiter = hs.get_login_ratelimiter() self._address_ratelimiter = Ratelimiter(
self._failed_attempts_ratelimiter = hs.get_login_failed_attempts_ratelimiter() clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_address.per_second,
burst_count=self.hs.config.rc_login_address.burst_count,
)
self._account_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
)
print(
"Creating fail ratelimiter: %s %s" % (
self.hs.config.rc_login_failed_attempts.per_second,
self.hs.config.rc_login_failed_attempts.burst_count,
),
)
self._failed_attempts_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
def on_GET(self, request): def on_GET(self, request):
flows = [] flows = []
@ -127,9 +146,7 @@ class LoginRestServlet(RestServlet):
return 200, {} return 200, {}
async def on_POST(self, request): async def on_POST(self, request):
self._account_ratelimiter.ratelimit( self._address_ratelimiter.ratelimit(request.getClientIP())
request.getClientIP(), time_now_s=self.hs.clock.time(), update=True,
)
login_submission = parse_json_object_from_request(request) login_submission = parse_json_object_from_request(request)
try: try:
@ -197,9 +214,7 @@ class LoginRestServlet(RestServlet):
# We also apply account rate limiting using the 3PID as a key, as # We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID. # otherwise using 3PID bypasses the ratelimiting based on user ID.
self._failed_attempts_ratelimiter.ratelimit( self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)
(medium, address), time_now_s=self._clock.time(), update=False,
)
# Check for login providers that support 3pid login types # Check for login providers that support 3pid login types
( (
@ -233,9 +248,7 @@ class LoginRestServlet(RestServlet):
# If it returned None but the 3PID was bound then we won't hit # If it returned None but the 3PID was bound then we won't hit
# this code path, which is fine as then the per-user ratelimit # this code path, which is fine as then the per-user ratelimit
# will kick in below. # will kick in below.
self._failed_attempts_ratelimiter.can_do_action( self._failed_attempts_ratelimiter.can_do_action((medium, address))
(medium, address), time_now_s=self._clock.time(), update=True,
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = {"type": "m.id.user", "user": user_id} identifier = {"type": "m.id.user", "user": user_id}
@ -253,9 +266,7 @@ class LoginRestServlet(RestServlet):
qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string() qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string()
# Check if we've hit the failed ratelimit (but don't update it) # Check if we've hit the failed ratelimit (but don't update it)
self._failed_attempts_ratelimiter.ratelimit( self._failed_attempts_ratelimiter.ratelimit(qualified_user_id.lower(), update=False)
qualified_user_id.lower(), time_now_s=self._clock.time(), update=False,
)
try: try:
canonical_user_id, callback = await self.auth_handler.validate_login( canonical_user_id, callback = await self.auth_handler.validate_login(
@ -266,9 +277,7 @@ class LoginRestServlet(RestServlet):
# limiter. Using `can_do_action` avoids us raising a ratelimit # limiter. Using `can_do_action` avoids us raising a ratelimit
# exception and masking the LoginError. The actual ratelimiting # exception and masking the LoginError. The actual ratelimiting
# should have happened above. # should have happened above.
self._failed_attempts_ratelimiter.can_do_action( self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower())
qualified_user_id.lower(), time_now_s=self._clock.time(), update=True,
)
raise raise
result = await self._complete_login( result = await self._complete_login(
@ -301,9 +310,7 @@ class LoginRestServlet(RestServlet):
# Before we actually log them in we check if they've already logged in # Before we actually log them in we check if they've already logged in
# too often. This happens here rather than before as we don't # too often. This happens here rather than before as we don't
# necessarily know the user before now. # necessarily know the user before now.
self._account_ratelimiter.ratelimit( self._account_ratelimiter.ratelimit(user_id.lower())
user_id.lower(), time_now_s=self._clock.time(), update=True,
)
if create_non_existant_users: if create_non_existant_users:
user_id = await self.auth_handler.check_user_exists(user_id) user_id = await self.auth_handler.check_user_exists(user_id)

View file

@ -396,16 +396,7 @@ class RegisterRestServlet(RestServlet):
client_addr = request.getClientIP() client_addr = request.getClientIP()
time_now = self.clock.time() self.ratelimiter.ratelimit(client_addr, update=False)
allowed, time_allowed = self.ratelimiter.can_do_action(
client_addr, time_now_s=time_now, update=False,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now))
)
kind = b"user" kind = b"user"
if b"kind" in request.args: if b"kind" in request.args:

View file

@ -243,28 +243,12 @@ class HomeServer(object):
self.clock = Clock(reactor) self.clock = Clock(reactor)
self.distributor = Distributor() self.distributor = Distributor()
# The rate_hz and burst_count is overridden on a per-user basis
self.request_ratelimiter = Ratelimiter(rate_hz=0, burst_count=0,)
if config.rc_admin_redaction:
self.admin_redaction_ratelimiter = Ratelimiter(
rate_hz=config.rc_admin_redaction.per_second,
burst_count=config.rc_admin_redaction.burst_count,
)
else:
self.admin_redaction_ratelimiter = None
self.registration_ratelimiter = Ratelimiter( self.registration_ratelimiter = Ratelimiter(
clock=self.clock,
rate_hz=config.rc_registration.per_second, rate_hz=config.rc_registration.per_second,
burst_count=config.rc_registration.burst_count, burst_count=config.rc_registration.burst_count,
) )
self.login_ratelimiter = Ratelimiter(
rate_hz=config.rc_login_account.per_second,
burst_count=config.rc_login_account.burst_count,
)
self.login_failed_attempts_ratelimiter = Ratelimiter(
rate_hz=config.rc_login_failed_attempts.per_second,
burst_count=config.rc_login_failed_attempts.burst_count,
)
self.datastores = None self.datastores = None
@ -334,21 +318,9 @@ class HomeServer(object):
def get_distributor(self): def get_distributor(self):
return self.distributor return self.distributor
def get_request_ratelimiter(self) -> Ratelimiter:
return self.request_ratelimiter
def get_registration_ratelimiter(self) -> Ratelimiter: def get_registration_ratelimiter(self) -> Ratelimiter:
return self.registration_ratelimiter return self.registration_ratelimiter
def get_admin_redaction_ratelimiter(self) -> Optional[Ratelimiter]:
return self.admin_redaction_ratelimiter
def get_login_ratelimiter(self) -> Ratelimiter:
return self.login_ratelimiter
def get_login_failed_attempts_ratelimiter(self) -> Ratelimiter:
return self.login_failed_attempts_ratelimiter
def build_federation_client(self): def build_federation_client(self):
return FederationClient(self) return FederationClient(self)

View file

@ -5,7 +5,7 @@ from tests import unittest
class TestRatelimiter(unittest.TestCase): class TestRatelimiter(unittest.TestCase):
def test_allowed(self): def test_allowed(self):
limiter = Ratelimiter(rate_hz=0.1, burst_count=1) limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
allowed, time_allowed = limiter.can_do_action(key="test_id", time_now_s=0) allowed, time_allowed = limiter.can_do_action(key="test_id", time_now_s=0)
self.assertTrue(allowed) self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed) self.assertEquals(10.0, time_allowed)
@ -19,7 +19,7 @@ class TestRatelimiter(unittest.TestCase):
self.assertEquals(20.0, time_allowed) self.assertEquals(20.0, time_allowed)
def test_pruning(self): def test_pruning(self):
limiter = Ratelimiter(rate_hz=0.1, burst_count=1) limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
_, _ = limiter.can_do_action(key="test_id_1", time_now_s=0) _, _ = limiter.can_do_action(key="test_id_1", time_now_s=0)
self.assertIn("test_id_1", limiter.actions) self.assertIn("test_id_1", limiter.actions)

View file

@ -14,12 +14,13 @@
# limitations under the License. # limitations under the License.
from mock import Mock, NonCallableMock from mock import Mock, patch
from twisted.internet import defer from twisted.internet import defer
import synapse.types import synapse.types
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.profile import MasterProfileHandler from synapse.handlers.profile import MasterProfileHandler
from synapse.types import UserID from synapse.types import UserID
@ -55,17 +56,15 @@ class ProfileTestCase(unittest.TestCase):
federation_client=self.mock_federation, federation_client=self.mock_federation,
federation_server=Mock(), federation_server=Mock(),
federation_registry=self.mock_registry, federation_registry=self.mock_registry,
request_ratelimiter=NonCallableMock(
spec_set=["can_do_action", "ratelimit"]
),
login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]),
) )
self.request_ratelimiter = hs.get_request_ratelimiter() # Patch Ratelimiter to allow all requests
self.request_ratelimiter.can_do_action.return_value = (True, 0) patch.object(
Ratelimiter, "can_do_action", new_callable=lambda *args, **kwargs: (True, 0.0)
self.login_ratelimiter = hs.get_login_ratelimiter() )
self.login_ratelimiter.can_do_action.return_value = (True, 0) patch.object(
Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None
)
self.store = hs.get_datastore() self.store = hs.get_datastore()

View file

@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from mock import Mock, NonCallableMock from mock import Mock, patch
from synapse.api.ratelimiting import Ratelimiter
from tests.replication._base import BaseStreamTestCase from tests.replication._base import BaseStreamTestCase
@ -23,18 +24,15 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
federation_client=Mock(), federation_client=Mock(),
request_ratelimiter=NonCallableMock(
spec_set=["can_do_action", "ratelimit"]
),
login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]),
) )
# Prevent ratelimiting # Patch Ratelimiter to allow all requests
self.request_ratelimiter = hs.get_request_ratelimiter() patch.object(
self.request_ratelimiter.can_do_action.return_value = (True, 0) Ratelimiter, "can_do_action", new_callable=lambda *args, **kwargs: (True, 0.0)
)
self.login_ratelimiter = hs.get_login_ratelimiter() patch.object(
self.login_ratelimiter.can_do_action.return_value = (True, 0) Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None
)
return hs return hs

View file

@ -15,10 +15,11 @@
""" Tests REST events for /events paths.""" """ Tests REST events for /events paths."""
from mock import Mock, NonCallableMock from mock import Mock, patch
import synapse.rest.admin import synapse.rest.admin
from synapse.rest.client.v1 import events, login, room from synapse.rest.client.v1 import events, login, room
from synapse.api.ratelimiting import Ratelimiter
from tests import unittest from tests import unittest
@ -42,17 +43,15 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
config=config, config=config,
request_ratelimiter=NonCallableMock(
# rate_hz and burst_count are overridden in BaseHandler
spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"]
),
login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]),
) )
self.request_ratelimiter = hs.get_request_ratelimiter()
self.request_ratelimiter.can_do_action.return_value = (True, 0)
self.login_ratelimiter = hs.get_login_ratelimiter() # Patch Ratelimiter to allow all requests
self.login_ratelimiter.can_do_action.return_value = (True, 0) patch.object(
Ratelimiter, "can_do_action", new_callable=lambda *args, **kwargs: (True, 0.0)
)
patch.object(
Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None
)
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()

View file

@ -7,6 +7,7 @@ import synapse.rest.admin
from synapse.rest.client.v1 import login, logout from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices from synapse.rest.client.v2_alpha import devices
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
from synapse.api.ratelimiting import Ratelimiter
from tests import unittest from tests import unittest
from tests.unittest import override_config from tests.unittest import override_config
@ -26,7 +27,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver() self.hs = self.setup_test_homeserver()
self.hs.config.enable_registration = True self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = [] self.hs.config.registrations_require_3pid = []
@ -35,10 +35,17 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
return self.hs return self.hs
@override_config(
{
"rc_login": {
"account": {
"per_second": 0.17,
"burst_count": 5,
}
}
}
)
def test_POST_ratelimiting_per_address(self): def test_POST_ratelimiting_per_address(self):
self.hs.get_login_ratelimiter().burst_count = 5
self.hs.get_login_ratelimiter().rate_hz = 0.17
# Create different users so we're sure not to be bothered by the per-user # Create different users so we're sure not to be bothered by the per-user
# ratelimiter. # ratelimiter.
for i in range(0, 6): for i in range(0, 6):
@ -77,10 +84,17 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
@override_config(
{
"rc_login": {
"account": {
"per_second": 0.17,
"burst_count": 5,
}
}
}
)
def test_POST_ratelimiting_per_account(self): def test_POST_ratelimiting_per_account(self):
self.hs.get_login_ratelimiter().burst_count = 5
self.hs.get_login_ratelimiter().rate_hz = 0.17
self.register_user("kermit", "monkey") self.register_user("kermit", "monkey")
for i in range(0, 6): for i in range(0, 6):
@ -116,10 +130,23 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
@override_config(
{
"rc_login": {
# Prevent the generic login ratelimiter from raising first
"address": {
"per_second": 1000,
"burst_count": 1000,
},
"failed_attempts": {
"per_second": 0.17,
"burst_count": 5,
}
}
}
)
@unittest.DEBUG
def test_POST_ratelimiting_per_account_failed_attempts(self): def test_POST_ratelimiting_per_account_failed_attempts(self):
self.hs.get_login_failed_attempts_ratelimiter().burst_count = 5
self.hs.get_login_failed_attempts_ratelimiter().rate_hz = 0.17
self.register_user("kermit", "monkey") self.register_user("kermit", "monkey")
for i in range(0, 6): for i in range(0, 6):
@ -128,8 +155,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"}, "identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey", "password": "notamonkey",
} }
request_data = json.dumps(params) request, channel = self.make_request(b"POST", LOGIN_URL, params)
request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
self.render(request) self.render(request)
if i == 5: if i == 5:
@ -149,7 +175,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"}, "identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey", "password": "notamonkey",
} }
request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params) request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request) self.render(request)

View file

@ -20,12 +20,13 @@
import json import json
from mock import Mock, NonCallableMock from mock import Mock, patch
from six.moves.urllib import parse as urlparse from six.moves.urllib import parse as urlparse
from twisted.internet import defer from twisted.internet import defer
import synapse.rest.admin import synapse.rest.admin
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v1 import directory, login, profile, room
@ -49,16 +50,15 @@ class RoomBase(unittest.HomeserverTestCase):
"red", "red",
http_client=None, http_client=None,
federation_client=Mock(), federation_client=Mock(),
request_ratelimiter=NonCallableMock(
spec_set=["can_do_action", "ratelimit"]
),
login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]),
) )
self.request_ratelimiter = self.hs.get_request_ratelimiter()
self.request_ratelimiter.can_do_action.return_value = (True, 0)
self.login_ratelimiter = self.hs.get_login_ratelimiter() # Patch Ratelimiter to allow all requests
self.login_ratelimiter.can_do_action.return_value = (True, 0) patch.object(
Ratelimiter, "can_do_action", new_callable=lambda *args, **kwargs: (True, 0.0)
)
patch.object(
Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None
)
self.hs.get_federation_handler = Mock(return_value=Mock()) self.hs.get_federation_handler = Mock(return_value=Mock())

View file

@ -16,12 +16,13 @@
"""Tests REST events for /rooms paths.""" """Tests REST events for /rooms paths."""
from mock import Mock, NonCallableMock from mock import Mock, NonCallableMock, patch
from twisted.internet import defer from twisted.internet import defer
from synapse.rest.client.v1 import room from synapse.rest.client.v1 import room
from synapse.types import UserID from synapse.types import UserID
from synapse.api.ratelimiting import Ratelimiter
from tests import unittest from tests import unittest
@ -42,20 +43,18 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
"red", "red",
http_client=None, http_client=None,
federation_client=Mock(), federation_client=Mock(),
request_ratelimiter=NonCallableMock( )
spec_set=["can_do_action", "ratelimit"]
), # Patch Ratelimiter to allow all requests
login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), patch.object(
Ratelimiter, "can_do_action", new_callable=lambda *args, **kwargs: (True, 0.0)
)
patch.object(
Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None
) )
self.event_source = hs.get_event_sources().sources["typing"] self.event_source = hs.get_event_sources().sources["typing"]
self.request_ratelimiter = hs.get_request_ratelimiter()
self.request_ratelimiter.can_do_action.return_value = (True, 0)
self.login_ratelimiter = hs.get_login_ratelimiter()
self.login_ratelimiter.can_do_action.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def get_user_by_access_token(token=None, allow_guest=False): def get_user_by_access_token(token=None, allow_guest=False):