Fix WheelTimer implementation that can expired timeout early (#17850)

When entries insert in the end of timer queue, then unnecessary entry
inserted (with duplicated key).
This can lead to some timeouts expired early and consume memory.
This commit is contained in:
Alexander Udovichenko 2024-11-05 21:08:17 +03:00 committed by GitHub
parent 361bdafb87
commit 211c31dbd7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 29 additions and 28 deletions

1
changelog.d/17850.bugfix Normal file
View file

@ -0,0 +1 @@
Fix bug when some presence and typing timeouts can expire early.

View file

@ -47,7 +47,6 @@ class WheelTimer(Generic[T]):
""" """
self.bucket_size: int = bucket_size self.bucket_size: int = bucket_size
self.entries: List[_Entry[T]] = [] self.entries: List[_Entry[T]] = []
self.current_tick: int = 0
def insert(self, now: int, obj: T, then: int) -> None: def insert(self, now: int, obj: T, then: int) -> None:
"""Inserts object into timer. """Inserts object into timer.
@ -78,11 +77,10 @@ class WheelTimer(Generic[T]):
self.entries[max(min_key, then_key) - min_key].elements.add(obj) self.entries[max(min_key, then_key) - min_key].elements.add(obj)
return return
next_key = now_key + 1
if self.entries: if self.entries:
last_key = self.entries[-1].end_key last_key = self.entries[-1].end_key + 1
else: else:
last_key = next_key last_key = now_key + 1
# Handle the case when `then` is in the past and `entries` is empty. # Handle the case when `then` is in the past and `entries` is empty.
then_key = max(last_key, then_key) then_key = max(last_key, then_key)

View file

@ -28,53 +28,55 @@ class WheelTimerTestCase(unittest.TestCase):
def test_single_insert_fetch(self) -> None: def test_single_insert_fetch(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5) wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj = object() wheel.insert(100, "1", 150)
wheel.insert(100, obj, 150)
self.assertListEqual(wheel.fetch(101), []) self.assertListEqual(wheel.fetch(101), [])
self.assertListEqual(wheel.fetch(110), []) self.assertListEqual(wheel.fetch(110), [])
self.assertListEqual(wheel.fetch(120), []) self.assertListEqual(wheel.fetch(120), [])
self.assertListEqual(wheel.fetch(130), []) self.assertListEqual(wheel.fetch(130), [])
self.assertListEqual(wheel.fetch(149), []) self.assertListEqual(wheel.fetch(149), [])
self.assertListEqual(wheel.fetch(156), [obj]) self.assertListEqual(wheel.fetch(156), ["1"])
self.assertListEqual(wheel.fetch(170), []) self.assertListEqual(wheel.fetch(170), [])
def test_multi_insert(self) -> None: def test_multi_insert(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5) wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj1 = object() wheel.insert(100, "1", 150)
obj2 = object() wheel.insert(105, "2", 130)
obj3 = object() wheel.insert(106, "3", 160)
wheel.insert(100, obj1, 150)
wheel.insert(105, obj2, 130)
wheel.insert(106, obj3, 160)
self.assertListEqual(wheel.fetch(110), []) self.assertListEqual(wheel.fetch(110), [])
self.assertListEqual(wheel.fetch(135), [obj2]) self.assertListEqual(wheel.fetch(135), ["2"])
self.assertListEqual(wheel.fetch(149), []) self.assertListEqual(wheel.fetch(149), [])
self.assertListEqual(wheel.fetch(158), [obj1]) self.assertListEqual(wheel.fetch(158), ["1"])
self.assertListEqual(wheel.fetch(160), []) self.assertListEqual(wheel.fetch(160), [])
self.assertListEqual(wheel.fetch(200), [obj3]) self.assertListEqual(wheel.fetch(200), ["3"])
self.assertListEqual(wheel.fetch(210), []) self.assertListEqual(wheel.fetch(210), [])
def test_insert_past(self) -> None: def test_insert_past(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5) wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj = object() wheel.insert(100, "1", 50)
wheel.insert(100, obj, 50) self.assertListEqual(wheel.fetch(120), ["1"])
self.assertListEqual(wheel.fetch(120), [obj])
def test_insert_past_multi(self) -> None: def test_insert_past_multi(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5) wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj1 = object() wheel.insert(100, "1", 150)
obj2 = object() wheel.insert(100, "2", 140)
obj3 = object() wheel.insert(100, "3", 50)
wheel.insert(100, obj1, 150) self.assertListEqual(wheel.fetch(110), ["3"])
wheel.insert(100, obj2, 140)
wheel.insert(100, obj3, 50)
self.assertListEqual(wheel.fetch(110), [obj3])
self.assertListEqual(wheel.fetch(120), []) self.assertListEqual(wheel.fetch(120), [])
self.assertListEqual(wheel.fetch(147), [obj2]) self.assertListEqual(wheel.fetch(147), ["2"])
self.assertListEqual(wheel.fetch(200), [obj1]) self.assertListEqual(wheel.fetch(200), ["1"])
self.assertListEqual(wheel.fetch(240), []) self.assertListEqual(wheel.fetch(240), [])
def test_multi_insert_then_past(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
wheel.insert(100, "1", 150)
wheel.insert(100, "2", 160)
wheel.insert(100, "3", 155)
self.assertListEqual(wheel.fetch(110), [])
self.assertListEqual(wheel.fetch(158), ["1"])