From a308d99f30d7e660115e355c54c37ac149cdbe53 Mon Sep 17 00:00:00 2001
From: Eric Eastwood
Date: Tue, 13 Aug 2024 12:27:42 -0500
Subject: [PATCH 1/3] Sliding Sync: Exclude partially stated rooms if we must
 await full state (#17538)

Previously, we only had very basic partial-room exclusion based on whether
we were lazy-loading room members. With this PR, we add
`must_await_full_state(...)` with rules that check whether the requested
`required_state` can be completely satisfied even with partial state.

Partially-stated rooms should have all state events except for remote
membership events, so if we require a remote membership event anywhere,
we need to return `True`.
---
 changelog.d/17538.bugfix                      |   1 +
 synapse/handlers/sliding_sync.py              | 104 +++++++--
 .../sliding_sync/test_rooms_required_state.py | 217 ++++++++++++++----
 3 files changed, 266 insertions(+), 56 deletions(-)
 create mode 100644 changelog.d/17538.bugfix

diff --git a/changelog.d/17538.bugfix b/changelog.d/17538.bugfix
new file mode 100644
index 0000000000..9e4e31dbdb
--- /dev/null
+++ b/changelog.d/17538.bugfix
@@ -0,0 +1 @@
+Better exclude partially stated rooms if we must await full state in the experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index 18a96843be..99510254f3 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -24,6 +24,7 @@ from itertools import chain
 from typing import (
     TYPE_CHECKING,
     Any,
+    Callable,
     Dict,
     Final,
     List,
@@ -366,6 +367,73 @@ class RoomSyncConfig:
             else:
                 self.required_state_map[state_type].add(state_key)
 
+    def must_await_full_state(
+        self,
+        is_mine_id: Callable[[str], bool],
+    ) -> bool:
+        """
+        Check whether we're only requesting `required_state` that is completely
+        satisfied even with partial state; if so, we don't need to
+        `await_full_state` before we can return it.
+
+        Also see `StateFilter.must_await_full_state(...)` for comparison.
+
+        Partially-stated rooms should have all state events except for remote
+        membership events, so if we require a remote membership event anywhere,
+        then we need to return `True` (requires full state).
+
+        Args:
+            is_mine_id: a callable which confirms if a given state_key matches a
+                mxid of a local user
+        """
+        wildcard_state_keys = self.required_state_map.get(StateValues.WILDCARD)
+        # Requesting *all* state in the room so we have to wait
+        if (
+            wildcard_state_keys is not None
+            and StateValues.WILDCARD in wildcard_state_keys
+        ):
+            return True
+
+        # If the wildcards don't refer to remote user IDs, then we don't need
+        # to wait for full state.
+        if wildcard_state_keys is not None:
+            for possible_user_id in wildcard_state_keys:
+                if not possible_user_id[0].startswith(UserID.SIGIL):
+                    # Not a user ID
+                    continue
+
+                localpart_hostname = possible_user_id.split(":", 1)
+                if len(localpart_hostname) < 2:
+                    # Not a user ID
+                    continue
+
+                if not is_mine_id(possible_user_id):
+                    return True
+
+        membership_state_keys = self.required_state_map.get(EventTypes.Member)
+        # We aren't requesting any membership events at all so the partial
+        # state will cover us.
+        if membership_state_keys is None:
+            return False
+
+        # If we're requesting entirely local users, the partial state will
+        # cover us.
+        for user_id in membership_state_keys:
+            if user_id == StateValues.ME:
+                continue
+            # We're lazy-loading membership, so we can just return the state we have.
+            # Lazy-loading means we include membership for any event `sender` in the
+            # timeline, but since we had to auth those timeline events, we will have
+            # the membership state for them (including from remote senders).
+            elif user_id == StateValues.LAZY:
+                continue
+            elif user_id == StateValues.WILDCARD:
+                return False
+            elif not is_mine_id(user_id):
+                return True
+
+        # Local users only so the partial state will cover us.
+        return False
+
 
 class StateValues:
     """
@@ -395,6 +463,7 @@ class SlidingSyncHandler:
         self.device_handler = hs.get_device_handler()
         self.push_rules_handler = hs.get_push_rules_handler()
         self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
+        self.is_mine_id = hs.is_mine_id
 
         self.connection_store = SlidingSyncConnectionStore()
 
@@ -575,19 +644,10 @@ class SlidingSyncHandler:
                 # Since creating the `RoomSyncConfig` takes some work, let's just do it
                 # once and make a copy whenever we need it.
                 room_sync_config = RoomSyncConfig.from_room_config(list_config)
-                membership_state_keys = room_sync_config.required_state_map.get(
-                    EventTypes.Member
-                )
-                # Also see `StateFilter.must_await_full_state(...)` for comparison
-                lazy_loading = (
-                    membership_state_keys is not None
-                    and StateValues.LAZY in membership_state_keys
-                )
 
-                if not lazy_loading:
-                    # Exclude partially-stated rooms unless the `required_state`
-                    # only has `["m.room.member", "$LAZY"]` for membership
-                    # (lazy-loading room members).
+                # Exclude partially-stated rooms if we must wait for the room to be
+                # fully-stated
+                if room_sync_config.must_await_full_state(self.is_mine_id):
                     filtered_sync_room_map = {
                         room_id: room
                         for room_id, room in filtered_sync_room_map.items()
@@ -654,6 +714,12 @@ class SlidingSyncHandler:
         # Handle room subscriptions
         if has_room_subscriptions and sync_config.room_subscriptions is not None:
             with start_active_span("assemble_room_subscriptions"):
+                # Find which rooms are partially stated and may need to be filtered out
+                # depending on the `required_state` requested (see below).
+                partial_state_room_map = await self.store.is_partial_state_room_batched(
+                    sync_config.room_subscriptions.keys()
+                )
+
                 for (
                     room_id,
                     room_subscription,
@@ -677,12 +743,20 @@ class SlidingSyncHandler:
                 )
 
                 # Take the superset of the `RoomSyncConfig` for each room.
-                #
-                # Update our `relevant_room_map` with the room we're going to display
-                # and need to fetch more info about.
                 room_sync_config = RoomSyncConfig.from_room_config(
                     room_subscription
                 )
+
+                # Exclude partially-stated rooms if we must wait for the room to be
+                # fully-stated
+                if room_sync_config.must_await_full_state(self.is_mine_id):
+                    if partial_state_room_map.get(room_id):
+                        continue
+
+                all_rooms.add(room_id)
+
+                # Update our `relevant_room_map` with the room we're going to display
+                # and need to fetch more info about.
                 existing_room_sync_config = relevant_room_map.get(room_id)
                 if existing_room_sync_config is not None:
                     existing_room_sync_config.combine_room_sync_config(
diff --git a/tests/rest/client/sliding_sync/test_rooms_required_state.py b/tests/rest/client/sliding_sync/test_rooms_required_state.py
index a13cad223f..823e7db569 100644
--- a/tests/rest/client/sliding_sync/test_rooms_required_state.py
+++ b/tests/rest/client/sliding_sync/test_rooms_required_state.py
@@ -631,8 +631,7 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
 
     def test_rooms_required_state_partial_state(self) -> None:
         """
-        Test partially-stated room are excluded unless `rooms.required_state` is
-        lazy-loading room members.
+        Test that partially-stated rooms are excluded if they require full state.
         """
         user1_id = self.register_user("user1", "pass")
         user1_tok = self.login(user1_id, "pass")
@@ -649,59 +648,195 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
             mark_event_as_partial_state(self.hs, join_response2["event_id"], room_id2)
         )
 
-        # Make the Sliding Sync request (NOT lazy-loading room members)
+        # Make the Sliding Sync request with examples where `must_await_full_state()`
+        # is `False`
         sync_body = {
             "lists": {
-                "foo-list": {
+                "no-state-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [],
+                    "timeline_limit": 0,
+                },
+                "other-state-list": {
                     "ranges": [[0, 1]],
                     "required_state": [
                         [EventTypes.Create, ""],
                     ],
                     "timeline_limit": 0,
                 },
+                "lazy-load-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [
+                        [EventTypes.Create, ""],
+                        # Lazy-load room members
+                        [EventTypes.Member, StateValues.LAZY],
+                        # Local member
+                        [EventTypes.Member, user2_id],
+                    ],
+                    "timeline_limit": 0,
+                },
+                "local-members-only-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [
+                        # Own user ID
+                        [EventTypes.Member, user1_id],
+                        # Local member
+                        [EventTypes.Member, user2_id],
+                    ],
+                    "timeline_limit": 0,
+                },
+                "me-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [
+                        # Own user ID
+                        [EventTypes.Member, StateValues.ME],
+                        # Local member
+                        [EventTypes.Member, user2_id],
+                    ],
+                    "timeline_limit": 0,
+                },
+                "wildcard-type-local-state-key-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [
+                        ["*", user1_id],
+                        # Not a user ID
+                        ["*", "foobarbaz"],
+                        # Not a user ID
+                        ["*", "foo.bar.baz"],
+                        # Not a user ID
+                        ["*", "@foo"],
+                    ],
+                    "timeline_limit": 0,
+                },
             }
         }
+        response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+
+        # The list should include both rooms now because we don't need full state
+        for list_key in response_body["lists"].keys():
+            self.assertIncludes(
+                set(response_body["lists"][list_key]["ops"][0]["room_ids"]),
+                {room_id2, room_id1},
+                exact=True,
+                message=f"Expected all rooms to show up for list_key={list_key}. Response "
+                + str(response_body["lists"][list_key]),
+            )
+
+        # Take each of the list variants and apply them to room subscriptions to make
+        # sure the same rules apply
+        for list_key in sync_body["lists"].keys():
+            sync_body_for_subscriptions = {
+                "room_subscriptions": {
+                    room_id1: {
+                        "required_state": sync_body["lists"][list_key][
+                            "required_state"
+                        ],
+                        "timeline_limit": 0,
+                    },
+                    room_id2: {
+                        "required_state": sync_body["lists"][list_key][
+                            "required_state"
+                        ],
+                        "timeline_limit": 0,
+                    },
+                }
+            }
+            response_body, _ = self.do_sync(sync_body_for_subscriptions, tok=user1_tok)
+
+            self.assertIncludes(
+                set(response_body["rooms"].keys()),
+                {room_id2, room_id1},
+                exact=True,
+                message=f"Expected all rooms to show up for test_key={list_key}.",
+            )
+
+        # =====================================================================
+
+        # Make the Sliding Sync request with examples where `must_await_full_state()`
+        # is `True`
+        sync_body = {
+            "lists": {
+                "wildcard-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [
+                        ["*", "*"],
+                    ],
+                    "timeline_limit": 0,
+                },
+                "wildcard-type-remote-state-key-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [
+                        ["*", "@some:remote"],
+                        # Not a user ID
+                        ["*", "foobarbaz"],
+                        # Not a user ID
+                        ["*", "foo.bar.baz"],
+                        # Not a user ID
+                        ["*", "@foo"],
+                    ],
+                    "timeline_limit": 0,
+                },
+                "remote-member-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [
+                        # Own user ID
+                        [EventTypes.Member, user1_id],
+                        # Remote member
+                        [EventTypes.Member, "@some:remote"],
+                        # Local member
+                        [EventTypes.Member, user2_id],
+                    ],
+                    "timeline_limit": 0,
+                },
+                "lazy-but-remote-member-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [
+                        # Lazy-load room members
+                        [EventTypes.Member, StateValues.LAZY],
+                        # Remote member
+                        [EventTypes.Member, "@some:remote"],
+                    ],
+                    "timeline_limit": 0,
+                },
             }
         }
         response_body, _ = self.do_sync(sync_body, tok=user1_tok)
 
         # Make sure the list includes room1 but room2 is excluded because it's still
         # partially-stated
-        self.assertListEqual(
-            list(response_body["lists"]["foo-list"]["ops"]),
-            [
-                {
-                    "op": "SYNC",
-                    "range": [0, 1],
-                    "room_ids": [room_id1],
-                }
-            ],
-            response_body["lists"]["foo-list"],
-        )
+        for list_key in response_body["lists"].keys():
+            self.assertIncludes(
+                set(response_body["lists"][list_key]["ops"][0]["room_ids"]),
+                {room_id1},
+                exact=True,
+                message=f"Expected only fully-stated rooms to show up for list_key={list_key}. Response "
+                + str(response_body["lists"][list_key]),
+            )
 
-        # Make the Sliding Sync request (with lazy-loading room members)
-        sync_body = {
-            "lists": {
-                "foo-list": {
-                    "ranges": [[0, 1]],
-                    "required_state": [
-                        [EventTypes.Create, ""],
-                        # Lazy-load room members
-                        [EventTypes.Member, StateValues.LAZY],
-                    ],
-                    "timeline_limit": 0,
-                },
+        # Take each of the list variants and apply them to room subscriptions to make
+        # sure the same rules apply
+        for list_key in sync_body["lists"].keys():
+            sync_body_for_subscriptions = {
+                "room_subscriptions": {
+                    room_id1: {
+                        "required_state": sync_body["lists"][list_key][
+                            "required_state"
+                        ],
+                        "timeline_limit": 0,
+                    },
+                    room_id2: {
+                        "required_state": sync_body["lists"][list_key][
+                            "required_state"
+                        ],
+                        "timeline_limit": 0,
+                    },
+                }
             }
-        }
-        response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+            response_body, _ = self.do_sync(sync_body_for_subscriptions, tok=user1_tok)
 
-        # The list should include both rooms now because we're lazy-loading room members
-        self.assertListEqual(
-            list(response_body["lists"]["foo-list"]["ops"]),
-            [
-                {
-                    "op": "SYNC",
-                    "range": [0, 1],
-                    "room_ids": [room_id2, room_id1],
-                }
-            ],
-            response_body["lists"]["foo-list"],
-        )
+            self.assertIncludes(
+                set(response_body["rooms"].keys()),
+                {room_id1},
+                exact=True,
+                message=f"Expected only fully-stated rooms to show up for test_key={list_key}.",
+            )
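
To make the new gating rules easier to follow, here is a minimal, self-contained sketch (not part of the patch) of the decision that `RoomSyncConfig.must_await_full_state` encodes above. The flat `required_state_map` dict, the constants, and the `is_local` helper are simplifications invented for this illustration; the diff above is the authoritative implementation.

from typing import Callable, Dict, Set

WILDCARD = "*"
ME = "$ME"
LAZY = "$LAZY"
MEMBER = "m.room.member"


def must_await_full_state(
    required_state_map: Dict[str, Set[str]],
    is_mine_id: Callable[[str], bool],
) -> bool:
    wildcard_state_keys = required_state_map.get(WILDCARD)
    # `["*", "*"]` requests all state, which partial state cannot satisfy.
    if wildcard_state_keys is not None and WILDCARD in wildcard_state_keys:
        return True

    # A wildcard event type with a *remote* user ID as the state key could
    # match a remote membership event that partial state may be missing.
    if wildcard_state_keys is not None:
        for state_key in wildcard_state_keys:
            is_user_id = state_key.startswith("@") and ":" in state_key
            if is_user_id and not is_mine_id(state_key):
                return True

    membership_state_keys = required_state_map.get(MEMBER)
    if membership_state_keys is None:
        # No membership requested at all: partial state covers us.
        return False

    for user_id in membership_state_keys:
        if user_id in (ME, LAZY):
            # Own membership is always local; lazy-loaded memberships come
            # from timeline senders, whose membership we already have.
            continue
        if user_id == WILDCARD:
            # Mirrors the patch: a wildcard membership state key returns False.
            return False
        if not is_mine_id(user_id):
            # An explicitly requested remote member requires full state.
            return True

    return False


def is_local(user_id: str) -> bool:
    return user_id.endswith(":example.com")


assert must_await_full_state({WILDCARD: {WILDCARD}}, is_local)
assert must_await_full_state({WILDCARD: {"@some:remote"}}, is_local)
assert must_await_full_state({MEMBER: {"@bob:remote.org"}}, is_local)
assert not must_await_full_state({MEMBER: {ME, LAZY, "@alice:example.com"}}, is_local)

These cases correspond one-to-one with the list names in the tests above (`wildcard-list`, `wildcard-type-remote-state-key-list`, `remote-member-list`, and the `False` variants).
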
From b05b2e14bbba0041e7818213b0885ec65540e617 Mon Sep 17 00:00:00 2001
From: Shay
Date: Wed, 14 Aug 2024 01:49:01 -0700
Subject: [PATCH 2/3] Handle lower-case HTTP headers in
 `_MultipartParserProtocol` (#17545)

---
 changelog.d/17545.bugfix  |  1 +
 synapse/http/client.py    |  6 +++---
 tests/http/test_client.py | 42 +++++++++++++++++++++++++++++++--------
 3 files changed, 38 insertions(+), 11 deletions(-)
 create mode 100644 changelog.d/17545.bugfix

diff --git a/changelog.d/17545.bugfix b/changelog.d/17545.bugfix
new file mode 100644
index 0000000000..31e22d873e
--- /dev/null
+++ b/changelog.d/17545.bugfix
@@ -0,0 +1 @@
+Handle lower-case HTTP headers in `_MultipartParserProtocol`.
\ No newline at end of file
diff --git a/synapse/http/client.py b/synapse/http/client.py
index daa5cc899b..cb4f72d771 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -1057,11 +1057,11 @@ class _MultipartParserProtocol(protocol.Protocol):
         if not self.parser:
 
             def on_header_field(data: bytes, start: int, end: int) -> None:
-                if data[start:end] == b"Location":
+                if data[start:end].lower() == b"location":
                     self.has_redirect = True
-                if data[start:end] == b"Content-Disposition":
+                if data[start:end].lower() == b"content-disposition":
                     self.in_disposition = True
-                if data[start:end] == b"Content-Type":
+                if data[start:end].lower() == b"content-type":
                     self.in_content_type = True
 
             def on_header_value(data: bytes, start: int, end: int) -> None:
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 721917f957..f2abec190b 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -49,8 +49,11 @@ from tests.unittest import TestCase
 
 
 class ReadMultipartResponseTests(TestCase):
-    data1 = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_"
-    data2 = b"to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
+    multipart_response_data1 = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_"
+    multipart_response_data2 = (
+        b"to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
+    )
+    multipart_response_data_cased = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\ncOntEnt-type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-tyPe: text/plain\r\nconTent-dispOsition: inline; filename=test_upload\r\n\r\nfile_"
 
     redirect_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nLocation: https://cdn.example.org/ab/c1/2345.txt\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
@@ -103,8 +106,31 @@ class ReadMultipartResponseTests(TestCase):
         result, deferred, protocol = self._build_multipart_response(249, 250)
 
         # Start sending data.
-        protocol.dataReceived(self.data1)
-        protocol.dataReceived(self.data2)
+        protocol.dataReceived(self.multipart_response_data1)
+        protocol.dataReceived(self.multipart_response_data2)
+        # Close the connection.
+        protocol.connectionLost(Failure(ResponseDone()))
+
+        multipart_response: MultipartResponse = deferred.result  # type: ignore[assignment]
+
+        self.assertEqual(multipart_response.json, b"{}")
+        self.assertEqual(result.getvalue(), b"file_to_stream")
+        self.assertEqual(multipart_response.length, len(b"file_to_stream"))
+        self.assertEqual(multipart_response.content_type, b"text/plain")
+        self.assertEqual(
+            multipart_response.disposition, b"inline; filename=test_upload"
+        )
+
+    def test_parse_file_lowercase_headers(self) -> None:
+        """
+        Check that a multipart response containing a file is properly parsed into
+        the json/file parts, and that the json and file are properly captured even
+        if the HTTP headers are lower-cased.
+        """
+        result, deferred, protocol = self._build_multipart_response(249, 250)
+
+        # Start sending data.
+        protocol.dataReceived(self.multipart_response_data_cased)
+        protocol.dataReceived(self.multipart_response_data2)
         # Close the connection.
         protocol.connectionLost(Failure(ResponseDone()))
@@ -143,7 +169,7 @@ class ReadMultipartResponseTests(TestCase):
         result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
 
         # Start sending data.
-        protocol.dataReceived(self.data1)
+        protocol.dataReceived(self.multipart_response_data1)
 
         self.assertEqual(result.getvalue(), b"file_")
         self._assert_error(deferred, protocol)
@@ -154,11 +180,11 @@ class ReadMultipartResponseTests(TestCase):
         result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
 
         # Start sending data.
-        protocol.dataReceived(self.data1)
+        protocol.dataReceived(self.multipart_response_data1)
 
         self._assert_error(deferred, protocol)
 
         # More data might have come in.
-        protocol.dataReceived(self.data2)
+        protocol.dataReceived(self.multipart_response_data2)
 
         self.assertEqual(result.getvalue(), b"file_")
         self._assert_error(deferred, protocol)
@@ -172,7 +198,7 @@ class ReadMultipartResponseTests(TestCase):
         self.assertFalse(deferred.called)
 
         # Start sending data.
-        protocol.dataReceived(self.data1)
+        protocol.dataReceived(self.multipart_response_data1)
 
         self._assert_error(deferred, protocol)
         self._cleanup_error(deferred)
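
The substance of the fix above is that HTTP header field names are case-insensitive, so the parser callbacks must normalise before comparing rather than matching the exact bytes the sender happened to use. A minimal sketch of that comparison (the `is_header` helper is invented for illustration; the patch simply inlines `.lower()` in `on_header_field`):

def is_header(data: bytes, start: int, end: int, name: bytes) -> bool:
    # Case-insensitively match a header-field slice against `name`.
    return data[start:end].lower() == name.lower()


buf = b"cOntEnt-type: application/json"
assert is_header(buf, 0, 12, b"Content-Type")
assert not is_header(buf, 0, 12, b"Location")

Note that only the header *names* are normalised; header values such as the filename in `Content-Disposition` keep their original casing.
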
From a51daffba5e58489f93f76a074aa7d6f73533226 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 14 Aug 2024 12:41:53 +0100
Subject: [PATCH 3/3] Reduce concurrent thread usage in media (#17567)

Follow on from #17558.

We want to reduce the number of threads in use at a time, i.e. reduce
the number of threads that are paused/blocked. We do this by returning
from the thread when the consumer pauses the producer, rather than
pausing in the thread.

---------

Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
---
 changelog.d/17567.misc        |  1 +
 synapse/media/_base.py        | 91 +++++++++++++++++++----------
 synapse/util/async_helpers.py | 43 +++++++++++++++++
 3 files changed, 93 insertions(+), 42 deletions(-)
 create mode 100644 changelog.d/17567.misc

diff --git a/changelog.d/17567.misc b/changelog.d/17567.misc
new file mode 100644
index 0000000000..cfa8089a81
--- /dev/null
+++ b/changelog.d/17567.misc
@@ -0,0 +1 @@
+Speed up responding to media requests.
diff --git a/synapse/media/_base.py b/synapse/media/_base.py
index 89dea39163..fdbbe29472 100644
--- a/synapse/media/_base.py
+++ b/synapse/media/_base.py
@@ -22,7 +22,6 @@
 
 import logging
 import os
-import threading
 import urllib
 from abc import ABC, abstractmethod
 from types import TracebackType
@@ -56,6 +55,7 @@ from synapse.logging.context import (
     run_in_background,
 )
 from synapse.util import Clock
+from synapse.util.async_helpers import DeferredEvent
 from synapse.util.stringutils import is_ascii
 
 if TYPE_CHECKING:
@@ -620,10 +620,13 @@ class ThreadedFileSender:
     A producer that sends the contents of a file to a consumer, reading from the
     file on a thread.
 
-    This works by spawning a loop in a threadpool that repeatedly reads from the
-    file and sends it to the consumer. The main thread communicates with the
-    loop via two `threading.Event`, which controls when to start/pause reading
-    and when to terminate.
+    This works by having a loop in a threadpool repeatedly reading from the
+    file, until the consumer pauses the producer. There is then a loop in the
+    main thread that waits until the consumer resumes the producer and then
+    starts reading in the threadpool again.
+
+    This is done to ensure that we're never waiting in the threadpool, as
+    otherwise it's easy to starve it of threads.
     """
 
     # How much data to read in one go.
@@ -643,12 +646,11 @@ class ThreadedFileSender:
 
         # Signals if the thread should keep reading/sending data. Set means
         # continue, clear means pause.
-        self.wakeup_event = threading.Event()
+        self.wakeup_event = DeferredEvent(self.reactor)
 
         # Signals if the thread should terminate, e.g. because the consumer has
-        # gone away. Both this and `wakeup_event` should be set to terminate the
-        # loop (otherwise the thread will block on `wakeup_event`).
-        self.stop_event = threading.Event()
+        # gone away.
+        self.stop_writing = False
 
     def beginFileTransfer(
         self, file: BinaryIO, consumer: interfaces.IConsumer
@@ -663,12 +665,7 @@ class ThreadedFileSender:
         # We set the wakeup signal as we should start producing immediately.
         self.wakeup_event.set()
 
-        run_in_background(
-            defer_to_threadpool,
-            self.reactor,
-            self.thread_pool,
-            self._on_thread_read_loop,
-        )
+        run_in_background(self.start_read_loop)
 
         return make_deferred_yieldable(self.deferred)
 
@@ -686,42 +683,52 @@ class ThreadedFileSender:
         # Unregister the consumer so we don't try and interact with it again.
         self.consumer = None
 
-        # Terminate the thread loop.
+        # Terminate the loop.
+        self.stop_writing = True
         self.wakeup_event.set()
-        self.stop_event.set()
 
         if not self.deferred.called:
             self.deferred.errback(Exception("Consumer asked us to stop producing"))
 
-    def _on_thread_read_loop(self) -> None:
-        """This is the loop that happens on a thread."""
-
+    async def start_read_loop(self) -> None:
+        """This is the loop that drives reading/writing"""
         try:
-            while not self.stop_event.is_set():
-                # We wait for the producer to signal that the consumer wants
-                # more data (or we should abort)
+            while not self.stop_writing:
+                # Start the loop in the threadpool to read data.
+                more_data = await defer_to_threadpool(
+                    self.reactor, self.thread_pool, self._on_thread_read_loop
+                )
+                if not more_data:
+                    # Reached EOF, we can just return.
+                    return
+
                 if not self.wakeup_event.is_set():
-                    ret = self.wakeup_event.wait(self.TIMEOUT_SECONDS)
+                    ret = await self.wakeup_event.wait(self.TIMEOUT_SECONDS)
                     if not ret:
                         raise Exception("Timed out waiting to resume")
-
-                # Check if we were woken up so that we abort the download
-                if self.stop_event.is_set():
-                    return
-
-                # The file should always have been set before we get here.
-                assert self.file is not None
-
-                chunk = self.file.read(self.CHUNK_SIZE)
-                if not chunk:
-                    return
-
-                self.reactor.callFromThread(self._write, chunk)
-
         except Exception:
-            self.reactor.callFromThread(self._error, Failure())
+            self._error(Failure())
         finally:
-            self.reactor.callFromThread(self._finish)
+            self._finish()
+
+    def _on_thread_read_loop(self) -> bool:
+        """This is the loop that happens on a thread.
+
+        Returns:
+            Whether there is more data to send.
+        """
+
+        while not self.stop_writing and self.wakeup_event.is_set():
+            # The file should always have been set before we get here.
+            assert self.file is not None
+
+            chunk = self.file.read(self.CHUNK_SIZE)
+            if not chunk:
+                return False
+
+            self.reactor.callFromThread(self._write, chunk)
+
+        return True
 
     def _write(self, chunk: bytes) -> None:
         """Called from the thread to write a chunk of data"""
@@ -729,7 +736,7 @@ class ThreadedFileSender:
             self.consumer.write(chunk)
 
     def _error(self, failure: Failure) -> None:
-        """Called from the thread when there was a fatal error"""
+        """Called when there was a fatal error"""
         if self.consumer:
             self.consumer.unregisterProducer()
             self.consumer = None
@@ -738,7 +745,7 @@ class ThreadedFileSender:
             self.deferred.errback(failure)
 
     def _finish(self) -> None:
-        """Called from the thread when it finishes (either on success or
+        """Called when we have finished writing (either on success or
         failure)."""
         if self.file:
             self.file.close()
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 70139beef2..8618bb0651 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -885,3 +885,46 @@ class AwakenableSleeper:
             # Cancel the sleep if we were woken up
             if call.active():
                 call.cancel()
+
+
+class DeferredEvent:
+    """Like threading.Event but for async code"""
+
+    def __init__(self, reactor: IReactorTime) -> None:
+        self._reactor = reactor
+        self._deferred: "defer.Deferred[None]" = defer.Deferred()
+
+    def set(self) -> None:
+        if not self._deferred.called:
+            self._deferred.callback(None)
+
+    def clear(self) -> None:
+        if self._deferred.called:
+            self._deferred = defer.Deferred()
+
+    def is_set(self) -> bool:
+        return self._deferred.called
+
+    async def wait(self, timeout_seconds: float) -> bool:
+        if self.is_set():
+            return True
+
+        # Create a deferred that gets called in N seconds
+        sleep_deferred: "defer.Deferred[None]" = defer.Deferred()
+        call = self._reactor.callLater(timeout_seconds, sleep_deferred.callback, None)
+
+        try:
+            await make_deferred_yieldable(
+                defer.DeferredList(
+                    [sleep_deferred, self._deferred],
+                    fireOnOneCallback=True,
+                    fireOnOneErrback=True,
+                    consumeErrors=True,
+                )
+            )
+        finally:
+            # Cancel the sleep if we were woken up
+            if call.active():
+                call.cancel()
+
+        return self.is_set()
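
As a closing usage sketch (not part of the patch), this is roughly how the new `DeferredEvent` is driven once the patch is applied: the consumer's pause/resume callbacks flip the event from the reactor thread, while `start_read_loop` awaits it. The `pause_producing`/`resume_producing` helper names are invented for illustration; in `ThreadedFileSender` the equivalent calls live in the producer interface methods.

from twisted.internet import reactor

from synapse.util.async_helpers import DeferredEvent

wakeup = DeferredEvent(reactor)


def pause_producing() -> None:
    # After clear(), a subsequent wait() suspends the awaiting coroutine
    # without parking a thread (unlike threading.Event.wait()).
    wakeup.clear()


def resume_producing() -> None:
    # set() fires the underlying Deferred, waking any pending wait() call.
    wakeup.set()


async def wait_until_resumed() -> None:
    # wait() returns True if woken (or already set), False on timeout.
    if not await wakeup.wait(90):
        raise Exception("Timed out waiting to resume")

The design point is that waiting now costs only a suspended coroutine rather than a blocked threadpool thread, which is what lets `_on_thread_read_loop` return to the pool whenever the consumer applies backpressure.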