Add a short sleep if the request is rate-limited (#17210)

This helps prevent clients from "tight-looping" retrying their request.
This commit is contained in:
Erik Johnston 2024-05-18 12:03:30 +01:00 committed by GitHub
parent 38f03a09ff
commit 52af16c561
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 15 additions and 4 deletions

1
changelog.d/17210.misc Normal file
View file

@ -0,0 +1 @@
Add a short pause when rate-limiting a request.

View file

@ -316,6 +316,10 @@ class Ratelimiter:
) )
if not allowed: if not allowed:
# We pause for a bit here to stop clients from "tight-looping" on
# retrying their request.
await self.clock.sleep(0.5)
raise LimitExceededError( raise LimitExceededError(
limiter_name=self._limiter_name, limiter_name=self._limiter_name,
retry_after_ms=int(1000 * (time_allowed - time_now_s)), retry_after_ms=int(1000 * (time_allowed - time_now_s)),

View file

@ -116,8 +116,9 @@ class TestRatelimiter(unittest.HomeserverTestCase):
# Should raise # Should raise
with self.assertRaises(LimitExceededError) as context: with self.assertRaises(LimitExceededError) as context:
self.get_success_or_raise( self.get_success_or_raise(
limiter.ratelimit(None, key="test_id", _time_now_s=5) limiter.ratelimit(None, key="test_id", _time_now_s=5), by=0.5
) )
self.assertEqual(context.exception.retry_after_ms, 5000) self.assertEqual(context.exception.retry_after_ms, 5000)
# Shouldn't raise # Shouldn't raise
@ -192,7 +193,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
# Second attempt, 1s later, will fail # Second attempt, 1s later, will fail
with self.assertRaises(LimitExceededError) as context: with self.assertRaises(LimitExceededError) as context:
self.get_success_or_raise( self.get_success_or_raise(
limiter.ratelimit(None, key=("test_id",), _time_now_s=1) limiter.ratelimit(None, key=("test_id",), _time_now_s=1), by=0.5
) )
self.assertEqual(context.exception.retry_after_ms, 9000) self.assertEqual(context.exception.retry_after_ms, 9000)

View file

@ -483,6 +483,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
event.room_version, event.room_version,
), ),
exc=LimitExceededError, exc=LimitExceededError,
by=0.5,
) )
def _build_and_send_join_event( def _build_and_send_join_event(

View file

@ -70,6 +70,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
action=Membership.JOIN, action=Membership.JOIN,
), ),
LimitExceededError, LimitExceededError,
by=0.5,
) )
@override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}}) @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
@ -206,6 +207,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
remote_room_hosts=[self.OTHER_SERVER_NAME], remote_room_hosts=[self.OTHER_SERVER_NAME],
), ),
LimitExceededError, LimitExceededError,
by=0.5,
) )
# TODO: test that remote joins to a room are rate limited. # TODO: test that remote joins to a room are rate limited.
@ -273,6 +275,7 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCa
action=Membership.JOIN, action=Membership.JOIN,
), ),
LimitExceededError, LimitExceededError,
by=0.5,
) )
# Try to join as Chris on the original worker. Should get denied because Alice # Try to join as Chris on the original worker. Should get denied because Alice
@ -285,6 +288,7 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCa
action=Membership.JOIN, action=Membership.JOIN,
), ),
LimitExceededError, LimitExceededError,
by=0.5,
) )

View file

@ -637,13 +637,13 @@ class HomeserverTestCase(TestCase):
return self.successResultOf(deferred) return self.successResultOf(deferred)
def get_failure( def get_failure(
self, d: Awaitable[Any], exc: Type[_ExcType] self, d: Awaitable[Any], exc: Type[_ExcType], by: float = 0.0
) -> _TypedFailure[_ExcType]: ) -> _TypedFailure[_ExcType]:
""" """
Run a Deferred and get a Failure from it. The failure must be of the type `exc`. Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
""" """
deferred: Deferred[Any] = ensureDeferred(d) # type: ignore[arg-type] deferred: Deferred[Any] = ensureDeferred(d) # type: ignore[arg-type]
self.pump() self.pump(by)
return self.failureResultOf(deferred, exc) return self.failureResultOf(deferred, exc)
def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV: def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV: