diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 79b7631172..9cd7f5cda2 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -16,6 +16,10 @@ from collections import OrderedDict from typing import Any, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.util import Clock +import logging + +logger = logging.getLogger(__name__) class Ratelimiter(object): @@ -23,24 +27,30 @@ class Ratelimiter(object): Ratelimit actions marked by arbitrary keys. Args: + clock: A homeserver clock, for retrieving the current time rate_hz: The long term number of actions that can be performed in a second. burst_count: How many actions that can be performed before being limited. """ - def __init__(self, rate_hz: float, burst_count: int): + def __init__(self, clock: Clock, rate_hz: float, burst_count: int): + self.clock = clock + self.rate_hz = rate_hz + self.burst_count = burst_count + # A ordered dictionary keeping track of actions, when they were last # performed and how often. Each entry is a mapping from a key of arbitrary type # to a tuple representing: # * How many times an action has occurred since a point in time - # * That point in time - self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int]] - self.rate_hz = rate_hz - self.burst_count = burst_count + # * The point in time + # * The rate_hz of this particular entry. This can vary per-request + self.actions = ( + OrderedDict() + ) # type: OrderedDict[Any, Tuple[float, int, Optional[float]]] def can_do_action( self, key: Any, - time_now_s: int, + time_now_s: Optional[int] = None, update: bool = True, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, @@ -50,7 +60,8 @@ class Ratelimiter(object): Args: key: The key we should use when rate limiting. Can be a user ID (when sending events), an IP address, etc. - time_now_s: The time now + time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Pretty much only used for tests. update: Whether to count this check as performing the action rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. @@ -64,14 +75,15 @@ class Ratelimiter(object): -1 if a rate_hz has not been defined for this Ratelimiter """ # Override default values if set - rate_hz = rate_hz or self.rate_hz - burst_count = burst_count or self.burst_count + time_now_s = time_now_s if time_now_s is not None else self.clock.time() + rate_hz = rate_hz if rate_hz is not None else self.rate_hz + burst_count = burst_count if burst_count is not None else self.burst_count # Remove any expired entries - self._prune_message_counts(time_now_s, rate_hz) + self._prune_message_counts(time_now_s) # Check if there is an existing count entry for this key - action_count, time_start, = self.actions.get(key, (0.0, time_now_s)) + action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, None)) # Check whether performing another action is allowed time_delta = time_now_s - time_start @@ -90,7 +102,10 @@ class Ratelimiter(object): action_count += 1.0 if update: - self.actions[key] = (action_count, time_start) + self.actions[key] = (action_count, time_start, rate_hz) + + logger.info("rate and burst: %s %s. performed_count: %s, allowed: %s", rate_hz, + burst_count, performed_count, allowed) # Figure out the time when an action can be performed again if self.rate_hz > 0: @@ -105,18 +120,17 @@ class Ratelimiter(object): return allowed, time_allowed - def _prune_message_counts(self, time_now_s: int, rate_hz: float): + def _prune_message_counts(self, time_now_s: int): """Remove message count entries that have not exceeded their defined rate_hz limit Args: time_now_s: The current time - rate_hz: The long term number of actions that can be performed in a second. """ # We create a copy of the key list here as the dictionary is modified during # the loop for key in list(self.actions.keys()): - action_count, time_start = self.actions[key] + action_count, time_start, rate_hz = self.actions[key] # Rate limit = "seconds since we started limiting this action" * rate_hz # If this limit has not been exceeded, wipe our record of this action @@ -129,7 +143,7 @@ class Ratelimiter(object): def ratelimit( self, key: Any, - time_now_s: int, + time_now_s: Optional[int] = None, update: bool = True, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, @@ -138,7 +152,8 @@ class Ratelimiter(object): Args: key: An arbitrary key used to classify an action - time_now_s: The current time + time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Pretty much only used for tests. update: Whether to count this check as performing the action rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. @@ -150,8 +165,9 @@ class Ratelimiter(object): milliseconds until the action can be performed again """ # Override default values if set - rate_hz = rate_hz or self.rate_hz - burst_count = burst_count or self.burst_count + time_now_s = time_now_s if time_now_s is not None else self.clock.time() + rate_hz = rate_hz if rate_hz is not None else self.rate_hz + burst_count = burst_count if burst_count is not None else self.burst_count allowed, time_allowed = self.can_do_action( key, time_now_s, update=update, rate_hz=rate_hz, burst_count=burst_count diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index e10e2427c4..0209bfe902 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -20,6 +20,7 @@ from twisted.internet import defer import synapse.types from synapse.api.constants import EventTypes, Membership from synapse.types import UserID +from synapse.api.ratelimiting import Ratelimiter logger = logging.getLogger(__name__) @@ -46,11 +47,20 @@ class BaseHandler(object): self.clock = hs.get_clock() self.hs = hs - self.request_ratelimiter = hs.get_request_ratelimiter() + # The rate_hz and burst_count are overridden on a per-user basis + self.request_ratelimiter = Ratelimiter(clock=self.clock, rate_hz=0, burst_count=0) self._rc_message = self.hs.config.rc_message - # If special admin redaction ratelimiting is disabled, this will be None - self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter() + # Check whether ratelimiting room admin message redaction is enabled + # by the presence of rate limits in the config + if self.hs.config.rc_admin_redaction: + self.admin_redaction_ratelimiter = Ratelimiter( + clock=self.clock, + rate_hz=self.hs.config.rc_admin_redaction.per_second, + burst_count=self.hs.config.rc_admin_redaction.burst_count, + ) + else: + self.admin_redaction_ratelimiter = None self.server_name = hs.hostname diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 089c94f8b6..8934911661 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -110,6 +110,7 @@ class AuthHandler(BaseHandler): # as per `rc_login.failed_attempts`. # XXX: Should this be hs.get_login_failed_attempts_ratelimiter? self._failed_uia_attempts_ratelimiter = Ratelimiter( + clock=self.clock, rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, ) @@ -200,9 +201,7 @@ class AuthHandler(BaseHandler): user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts - self._failed_uia_attempts_ratelimiter.ratelimit( - user_id, time_now_s=self._clock.time(), update=False, - ) + self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False) # build a list of supported flows flows = [[login_type] for login_type in self._supported_ui_auth_types] @@ -212,10 +211,8 @@ class AuthHandler(BaseHandler): flows, request, request_body, clientip, description ) except LoginError: - # Update the ratelimite to say we failed (`can_do_action` doesn't raise). - self._failed_uia_attempts_ratelimiter.can_do_action( - user_id, time_now_s=self._clock.time(), update=True, - ) + # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). + self._failed_uia_attempts_ratelimiter.can_do_action(user_id) raise # find the completed login type diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ce18b33a63..1b14b9b798 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -427,9 +427,7 @@ class RegistrationHandler(BaseHandler): time_now = self.clock.time() - self.ratelimiter.ratelimit( - address, time_now_s=time_now, - ) + self.ratelimiter.ratelimit(address) def register_with_store( self, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 19c392849a..9d674af9d2 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -27,6 +27,7 @@ from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import UserID from synapse.util.msisdn import phone_number_to_msisdn +from synapse.api.ratelimiting import Ratelimiter logger = logging.getLogger(__name__) @@ -86,10 +87,28 @@ class LoginRestServlet(RestServlet): self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() - self._clock = hs.get_clock() self._well_known_builder = WellKnownBuilder(hs) - self._account_ratelimiter = hs.get_login_ratelimiter() - self._failed_attempts_ratelimiter = hs.get_login_failed_attempts_ratelimiter() + self._address_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_address.per_second, + burst_count=self.hs.config.rc_login_address.burst_count, + ) + self._account_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_account.per_second, + burst_count=self.hs.config.rc_login_account.burst_count, + ) + print( + "Creating fail ratelimiter: %s %s" % ( + self.hs.config.rc_login_failed_attempts.per_second, + self.hs.config.rc_login_failed_attempts.burst_count, + ), + ) + self._failed_attempts_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + ) def on_GET(self, request): flows = [] @@ -127,9 +146,7 @@ class LoginRestServlet(RestServlet): return 200, {} async def on_POST(self, request): - self._account_ratelimiter.ratelimit( - request.getClientIP(), time_now_s=self.hs.clock.time(), update=True, - ) + self._address_ratelimiter.ratelimit(request.getClientIP()) login_submission = parse_json_object_from_request(request) try: @@ -197,9 +214,7 @@ class LoginRestServlet(RestServlet): # We also apply account rate limiting using the 3PID as a key, as # otherwise using 3PID bypasses the ratelimiting based on user ID. - self._failed_attempts_ratelimiter.ratelimit( - (medium, address), time_now_s=self._clock.time(), update=False, - ) + self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False) # Check for login providers that support 3pid login types ( @@ -233,9 +248,7 @@ class LoginRestServlet(RestServlet): # If it returned None but the 3PID was bound then we won't hit # this code path, which is fine as then the per-user ratelimit # will kick in below. - self._failed_attempts_ratelimiter.can_do_action( - (medium, address), time_now_s=self._clock.time(), update=True, - ) + self._failed_attempts_ratelimiter.can_do_action((medium, address)) raise LoginError(403, "", errcode=Codes.FORBIDDEN) identifier = {"type": "m.id.user", "user": user_id} @@ -253,9 +266,7 @@ class LoginRestServlet(RestServlet): qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string() # Check if we've hit the failed ratelimit (but don't update it) - self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), time_now_s=self._clock.time(), update=False, - ) + self._failed_attempts_ratelimiter.ratelimit(qualified_user_id.lower(), update=False) try: canonical_user_id, callback = await self.auth_handler.validate_login( @@ -266,9 +277,7 @@ class LoginRestServlet(RestServlet): # limiter. Using `can_do_action` avoids us raising a ratelimit # exception and masking the LoginError. The actual ratelimiting # should have happened above. - self._failed_attempts_ratelimiter.can_do_action( - qualified_user_id.lower(), time_now_s=self._clock.time(), update=True, - ) + self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower()) raise result = await self._complete_login( @@ -301,9 +310,7 @@ class LoginRestServlet(RestServlet): # Before we actually log them in we check if they've already logged in # too often. This happens here rather than before as we don't # necessarily know the user before now. - self._account_ratelimiter.ratelimit( - user_id.lower(), time_now_s=self._clock.time(), update=True, - ) + self._account_ratelimiter.ratelimit(user_id.lower()) if create_non_existant_users: user_id = await self.auth_handler.check_user_exists(user_id) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 8567cbcab3..380d75d7ce 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -396,16 +396,7 @@ class RegisterRestServlet(RestServlet): client_addr = request.getClientIP() - time_now = self.clock.time() - - allowed, time_allowed = self.ratelimiter.can_do_action( - client_addr, time_now_s=time_now, update=False, - ) - - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)) - ) + self.ratelimiter.ratelimit(client_addr, update=False) kind = b"user" if b"kind" in request.args: diff --git a/synapse/server.py b/synapse/server.py index fc39b57135..1f1e6d9ff2 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -243,28 +243,12 @@ class HomeServer(object): self.clock = Clock(reactor) self.distributor = Distributor() - # The rate_hz and burst_count is overridden on a per-user basis - self.request_ratelimiter = Ratelimiter(rate_hz=0, burst_count=0,) - if config.rc_admin_redaction: - self.admin_redaction_ratelimiter = Ratelimiter( - rate_hz=config.rc_admin_redaction.per_second, - burst_count=config.rc_admin_redaction.burst_count, - ) - else: - self.admin_redaction_ratelimiter = None self.registration_ratelimiter = Ratelimiter( + clock=self.clock, rate_hz=config.rc_registration.per_second, burst_count=config.rc_registration.burst_count, ) - self.login_ratelimiter = Ratelimiter( - rate_hz=config.rc_login_account.per_second, - burst_count=config.rc_login_account.burst_count, - ) - self.login_failed_attempts_ratelimiter = Ratelimiter( - rate_hz=config.rc_login_failed_attempts.per_second, - burst_count=config.rc_login_failed_attempts.burst_count, - ) self.datastores = None @@ -334,21 +318,9 @@ class HomeServer(object): def get_distributor(self): return self.distributor - def get_request_ratelimiter(self) -> Ratelimiter: - return self.request_ratelimiter - def get_registration_ratelimiter(self) -> Ratelimiter: return self.registration_ratelimiter - def get_admin_redaction_ratelimiter(self) -> Optional[Ratelimiter]: - return self.admin_redaction_ratelimiter - - def get_login_ratelimiter(self) -> Ratelimiter: - return self.login_ratelimiter - - def get_login_failed_attempts_ratelimiter(self) -> Ratelimiter: - return self.login_failed_attempts_ratelimiter - def build_federation_client(self): return FederationClient(self) diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 973c7e007c..12425b1faa 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -5,7 +5,7 @@ from tests import unittest class TestRatelimiter(unittest.TestCase): def test_allowed(self): - limiter = Ratelimiter(rate_hz=0.1, burst_count=1) + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) allowed, time_allowed = limiter.can_do_action(key="test_id", time_now_s=0) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) @@ -19,7 +19,7 @@ class TestRatelimiter(unittest.TestCase): self.assertEquals(20.0, time_allowed) def test_pruning(self): - limiter = Ratelimiter(rate_hz=0.1, burst_count=1) + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) _, _ = limiter.can_do_action(key="test_id_1", time_now_s=0) self.assertIn("test_id_1", limiter.actions) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 891c986fbc..5af3db2cd5 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,12 +14,13 @@ # limitations under the License. -from mock import Mock, NonCallableMock +from mock import Mock, patch from twisted.internet import defer import synapse.types from synapse.api.errors import AuthError, SynapseError +from synapse.api.ratelimiting import Ratelimiter from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID @@ -55,17 +56,15 @@ class ProfileTestCase(unittest.TestCase): federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, - request_ratelimiter=NonCallableMock( - spec_set=["can_do_action", "ratelimit"] - ), - login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) - self.request_ratelimiter = hs.get_request_ratelimiter() - self.request_ratelimiter.can_do_action.return_value = (True, 0) - - self.login_ratelimiter = hs.get_login_ratelimiter() - self.login_ratelimiter.can_do_action.return_value = (True, 0) + # Patch Ratelimiter to allow all requests + patch.object( + Ratelimiter, "can_do_action", new_callable=lambda *args, **kwargs: (True, 0.0) + ) + patch.object( + Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None + ) self.store = hs.get_datastore() diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 49d22d9487..8b9dc57b44 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -13,8 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock, NonCallableMock +from mock import Mock, patch +from synapse.api.ratelimiting import Ratelimiter from tests.replication._base import BaseStreamTestCase @@ -23,18 +24,15 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase): hs = self.setup_test_homeserver( federation_client=Mock(), - request_ratelimiter=NonCallableMock( - spec_set=["can_do_action", "ratelimit"] - ), - login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) - # Prevent ratelimiting - self.request_ratelimiter = hs.get_request_ratelimiter() - self.request_ratelimiter.can_do_action.return_value = (True, 0) - - self.login_ratelimiter = hs.get_login_ratelimiter() - self.login_ratelimiter.can_do_action.return_value = (True, 0) + # Patch Ratelimiter to allow all requests + patch.object( + Ratelimiter, "can_do_action", new_callable=lambda *args, **kwargs: (True, 0.0) + ) + patch.object( + Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None + ) return hs diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 1ceba01494..1b946388a6 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -15,10 +15,11 @@ """ Tests REST events for /events paths.""" -from mock import Mock, NonCallableMock +from mock import Mock, patch import synapse.rest.admin from synapse.rest.client.v1 import events, login, room +from synapse.api.ratelimiting import Ratelimiter from tests import unittest @@ -42,17 +43,15 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver( config=config, - request_ratelimiter=NonCallableMock( - # rate_hz and burst_count are overridden in BaseHandler - spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"] - ), - login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) - self.request_ratelimiter = hs.get_request_ratelimiter() - self.request_ratelimiter.can_do_action.return_value = (True, 0) - self.login_ratelimiter = hs.get_login_ratelimiter() - self.login_ratelimiter.can_do_action.return_value = (True, 0) + # Patch Ratelimiter to allow all requests + patch.object( + Ratelimiter, "can_do_action", new_callable=lambda *args, **kwargs: (True, 0.0) + ) + patch.object( + Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None + ) hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index c01738ed69..a9952f53a6 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -7,6 +7,7 @@ import synapse.rest.admin from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices from synapse.rest.client.v2_alpha.account import WhoamiRestServlet +from synapse.api.ratelimiting import Ratelimiter from tests import unittest from tests.unittest import override_config @@ -26,7 +27,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() self.hs.config.enable_registration = True self.hs.config.registrations_require_3pid = [] @@ -35,10 +35,17 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): return self.hs + @override_config( + { + "rc_login": { + "account": { + "per_second": 0.17, + "burst_count": 5, + } + } + } + ) def test_POST_ratelimiting_per_address(self): - self.hs.get_login_ratelimiter().burst_count = 5 - self.hs.get_login_ratelimiter().rate_hz = 0.17 - # Create different users so we're sure not to be bothered by the per-user # ratelimiter. for i in range(0, 6): @@ -77,10 +84,17 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config( + { + "rc_login": { + "account": { + "per_second": 0.17, + "burst_count": 5, + } + } + } + ) def test_POST_ratelimiting_per_account(self): - self.hs.get_login_ratelimiter().burst_count = 5 - self.hs.get_login_ratelimiter().rate_hz = 0.17 - self.register_user("kermit", "monkey") for i in range(0, 6): @@ -116,10 +130,23 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config( + { + "rc_login": { + # Prevent the generic login ratelimiter from raising first + "address": { + "per_second": 1000, + "burst_count": 1000, + }, + "failed_attempts": { + "per_second": 0.17, + "burst_count": 5, + } + } + } + ) + @unittest.DEBUG def test_POST_ratelimiting_per_account_failed_attempts(self): - self.hs.get_login_failed_attempts_ratelimiter().burst_count = 5 - self.hs.get_login_failed_attempts_ratelimiter().rate_hz = 0.17 - self.register_user("kermit", "monkey") for i in range(0, 6): @@ -128,8 +155,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - request_data = json.dumps(params) - request, channel = self.make_request(b"POST", LOGIN_URL, request_data) + request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) if i == 5: @@ -149,7 +175,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - request_data = json.dumps(params) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index ba10f34468..8b19bcef60 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -20,12 +20,13 @@ import json -from mock import Mock, NonCallableMock +from mock import Mock, patch from six.moves.urllib import parse as urlparse from twisted.internet import defer import synapse.rest.admin +from synapse.api.ratelimiting import Ratelimiter from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus from synapse.rest.client.v1 import directory, login, profile, room @@ -49,16 +50,15 @@ class RoomBase(unittest.HomeserverTestCase): "red", http_client=None, federation_client=Mock(), - request_ratelimiter=NonCallableMock( - spec_set=["can_do_action", "ratelimit"] - ), - login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) - self.request_ratelimiter = self.hs.get_request_ratelimiter() - self.request_ratelimiter.can_do_action.return_value = (True, 0) - self.login_ratelimiter = self.hs.get_login_ratelimiter() - self.login_ratelimiter.can_do_action.return_value = (True, 0) + # Patch Ratelimiter to allow all requests + patch.object( + Ratelimiter, "can_do_action", new_callable=lambda *args, **kwargs: (True, 0.0) + ) + patch.object( + Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None + ) self.hs.get_federation_handler = Mock(return_value=Mock()) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 2ec678a2a2..f57d2f3356 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -16,12 +16,13 @@ """Tests REST events for /rooms paths.""" -from mock import Mock, NonCallableMock +from mock import Mock, NonCallableMock, patch from twisted.internet import defer from synapse.rest.client.v1 import room from synapse.types import UserID +from synapse.api.ratelimiting import Ratelimiter from tests import unittest @@ -42,20 +43,18 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): "red", http_client=None, federation_client=Mock(), - request_ratelimiter=NonCallableMock( - spec_set=["can_do_action", "ratelimit"] - ), - login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), + ) + + # Patch Ratelimiter to allow all requests + patch.object( + Ratelimiter, "can_do_action", new_callable=lambda *args, **kwargs: (True, 0.0) + ) + patch.object( + Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None ) self.event_source = hs.get_event_sources().sources["typing"] - self.request_ratelimiter = hs.get_request_ratelimiter() - self.request_ratelimiter.can_do_action.return_value = (True, 0) - - self.login_ratelimiter = hs.get_login_ratelimiter() - self.login_ratelimiter.can_do_action.return_value = (True, 0) - hs.get_handlers().federation_handler = Mock() def get_user_by_access_token(token=None, allow_guest=False):