Speed up sliding sync by computing extensions in parallel (#17884)

The main change here is to add a helper function
`gather_optional_coroutines`, which works similarly to
`yieldable_gather_results` but takes a set of coroutines rather than a
function.
Erik Johnston 2024-10-30 10:51:04 +00:00 committed by GitHub
parent 58deef5eba
commit 83513b75f7
5 changed files with 285 additions and 13 deletions
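
To make the commit message concrete, here is a minimal usage sketch (not part of this commit's diff; `fetch_profile` and `fetch_devices` are made-up async functions): `yieldable_gather_results` takes one async function applied over an iterable, while `gather_optional_coroutines` takes distinct, already-constructed coroutines, any of which may be `None`.

# A minimal sketch, assuming Synapse is importable; `fetch_profile` and
# `fetch_devices` are hypothetical async functions, not code from this commit.
from typing import List, Optional, Tuple

from synapse.util.async_helpers import (
    gather_optional_coroutines,
    yieldable_gather_results,
)


async def fetch_profile(user_id: str) -> str:
    return f"profile of {user_id}"


async def fetch_devices(user_id: str) -> List[str]:
    return [f"{user_id}:DEVICE1"]


async def example(
    user_ids: List[str], include_devices: bool
) -> Tuple[List[str], str, Optional[List[str]]]:
    # yieldable_gather_results: one async *function*, applied to every item
    # of an iterable; results come back as a list.
    profiles = await yieldable_gather_results(fetch_profile, user_ids)

    # gather_optional_coroutines: distinct, already-constructed coroutines,
    # any of which may be None; a None slot yields None in the result tuple.
    profile, devices = await gather_optional_coroutines(
        fetch_profile("@alice:example.com"),
        fetch_devices("@alice:example.com") if include_devices else None,
    )
    return profiles, profile, devices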

changelog.d/17884.misc (new file)

@@ -0,0 +1 @@
+Minor speed-up of sliding sync by computing extensions results in parallel.


@@ -49,7 +49,10 @@ from synapse.types.handlers.sliding_sync import (
     SlidingSyncConfig,
     SlidingSyncResult,
 )
-from synapse.util.async_helpers import concurrently_execute
+from synapse.util.async_helpers import (
+    concurrently_execute,
+    gather_optional_coroutines,
+)
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -97,26 +100,26 @@ class SlidingSyncExtensionHandler:
         if sync_config.extensions is None:
             return SlidingSyncResult.Extensions()
 
-        to_device_response = None
+        to_device_coro = None
         if sync_config.extensions.to_device is not None:
-            to_device_response = await self.get_to_device_extension_response(
+            to_device_coro = self.get_to_device_extension_response(
                 sync_config=sync_config,
                 to_device_request=sync_config.extensions.to_device,
                 to_token=to_token,
             )
 
-        e2ee_response = None
+        e2ee_coro = None
         if sync_config.extensions.e2ee is not None:
-            e2ee_response = await self.get_e2ee_extension_response(
+            e2ee_coro = self.get_e2ee_extension_response(
                 sync_config=sync_config,
                 e2ee_request=sync_config.extensions.e2ee,
                 to_token=to_token,
                 from_token=from_token,
             )
 
-        account_data_response = None
+        account_data_coro = None
         if sync_config.extensions.account_data is not None:
-            account_data_response = await self.get_account_data_extension_response(
+            account_data_coro = self.get_account_data_extension_response(
                 sync_config=sync_config,
                 previous_connection_state=previous_connection_state,
                 new_connection_state=new_connection_state,
@@ -127,9 +130,9 @@ class SlidingSyncExtensionHandler:
                 from_token=from_token,
             )
 
-        receipts_response = None
+        receipts_coro = None
         if sync_config.extensions.receipts is not None:
-            receipts_response = await self.get_receipts_extension_response(
+            receipts_coro = self.get_receipts_extension_response(
                 sync_config=sync_config,
                 previous_connection_state=previous_connection_state,
                 new_connection_state=new_connection_state,
@@ -141,9 +144,9 @@ class SlidingSyncExtensionHandler:
                 from_token=from_token,
             )
 
-        typing_response = None
+        typing_coro = None
         if sync_config.extensions.typing is not None:
-            typing_response = await self.get_typing_extension_response(
+            typing_coro = self.get_typing_extension_response(
                 sync_config=sync_config,
                 actual_lists=actual_lists,
                 actual_room_ids=actual_room_ids,
@@ -153,6 +156,20 @@ class SlidingSyncExtensionHandler:
                 from_token=from_token,
             )
 
+        (
+            to_device_response,
+            e2ee_response,
+            account_data_response,
+            receipts_response,
+            typing_response,
+        ) = await gather_optional_coroutines(
+            to_device_coro,
+            e2ee_coro,
+            account_data_coro,
+            receipts_coro,
+            typing_coro,
+        )
+
         return SlidingSyncResult.Extensions(
             to_device=to_device_response,
             e2ee=e2ee_response,


@@ -37,6 +37,7 @@ import warnings
 from types import TracebackType
 from typing import (
     TYPE_CHECKING,
+    Any,
     Awaitable,
     Callable,
     Optional,
@@ -850,6 +851,45 @@ def run_in_background(
     return d
 
 
+def run_coroutine_in_background(
+    coroutine: typing.Coroutine[Any, Any, R],
+) -> "defer.Deferred[R]":
+    """Run the coroutine, ensuring that the current context is restored after
+    return from the function, and that the sentinel context is set once the
+    deferred returned by the function completes.
+
+    Useful for wrapping coroutines that you don't yield or await on (for
+    instance because you want to pass it to deferred.gatherResults()).
+
+    This is a special case of `run_in_background` where we can accept a
+    coroutine directly rather than a function. We can do this because coroutines
+    do not run until called, and so calling an async function without awaiting
+    cannot change the log contexts.
+    """
+
+    current = current_context()
+    d = defer.ensureDeferred(coroutine)
+
+    # The function may have reset the context before returning, so
+    # we need to restore it now.
+    ctx = set_current_context(current)
+
+    # The original context will be restored when the deferred
+    # completes, but there is nothing waiting for it, so it will
+    # get leaked into the reactor or some other function which
+    # wasn't expecting it. We therefore need to reset the context
+    # here.
+    #
+    # (If this feels asymmetric, consider it this way: we are
+    # effectively forking a new thread of execution. We are
+    # probably currently within a ``with LoggingContext()`` block,
+    # which is supposed to have a single entry and exit point. But
+    # by spawning off another deferred, we are effectively
+    # adding a new exit point.)
+    d.addBoth(_set_context_cb, ctx)
+
+    return d
+
+
 T = TypeVar("T")
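
A small sketch of the distinction the docstring above draws (not part of this commit; the async function `do_work` is invented for illustration): `run_in_background` is handed a function plus its arguments, whereas `run_coroutine_in_background` is handed a coroutine that has already been created but not awaited.

# A minimal sketch, assuming Synapse is importable; `do_work` and `caller`
# are hypothetical and exist only to show the two call shapes.
from twisted.internet import defer

from synapse.logging.context import (
    make_deferred_yieldable,
    run_coroutine_in_background,
    run_in_background,
)


async def do_work(n: int) -> int:
    return n * 2


async def caller() -> None:
    # run_in_background takes a *function* (plus args) and calls it for us.
    d1 = run_in_background(do_work, 1)

    # run_coroutine_in_background takes an already-created *coroutine*. This
    # is safe because a coroutine does nothing until it is awaited, so simply
    # creating it cannot disturb the current log context.
    d2 = run_coroutine_in_background(do_work(2))

    # Both return Deferreds that can be gathered and awaited later on.
    results = await make_deferred_yieldable(
        defer.gatherResults([d1, d2], consumeErrors=True)
    )
    assert results == [2, 4]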


@@ -51,7 +51,7 @@ from typing import (
 )
 
 import attr
-from typing_extensions import Concatenate, Literal, ParamSpec
+from typing_extensions import Concatenate, Literal, ParamSpec, Unpack
 
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
@@ -61,6 +61,7 @@ from twisted.python.failure import Failure
 from synapse.logging.context import (
     PreserveLoggingContext,
     make_deferred_yieldable,
+    run_coroutine_in_background,
     run_in_background,
 )
 from synapse.util import Clock
@@ -344,6 +345,7 @@ T1 = TypeVar("T1")
 T2 = TypeVar("T2")
 T3 = TypeVar("T3")
 T4 = TypeVar("T4")
+T5 = TypeVar("T5")
 
 
 @overload
@@ -402,6 +404,112 @@ def gather_results(  # type: ignore[misc]
     return deferred.addCallback(tuple)
 
 
+@overload
+async def gather_optional_coroutines(
+    *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]]]],
+) -> Tuple[Optional[T1]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+    *coroutines: Unpack[
+        Tuple[
+            Optional[Coroutine[Any, Any, T1]],
+            Optional[Coroutine[Any, Any, T2]],
+        ]
+    ],
+) -> Tuple[Optional[T1], Optional[T2]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+    *coroutines: Unpack[
+        Tuple[
+            Optional[Coroutine[Any, Any, T1]],
+            Optional[Coroutine[Any, Any, T2]],
+            Optional[Coroutine[Any, Any, T3]],
+        ]
+    ],
+) -> Tuple[Optional[T1], Optional[T2], Optional[T3]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+    *coroutines: Unpack[
+        Tuple[
+            Optional[Coroutine[Any, Any, T1]],
+            Optional[Coroutine[Any, Any, T2]],
+            Optional[Coroutine[Any, Any, T3]],
+            Optional[Coroutine[Any, Any, T4]],
+        ]
+    ],
+) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4]]: ...
+
+
+@overload
+async def gather_optional_coroutines(
+    *coroutines: Unpack[
+        Tuple[
+            Optional[Coroutine[Any, Any, T1]],
+            Optional[Coroutine[Any, Any, T2]],
+            Optional[Coroutine[Any, Any, T3]],
+            Optional[Coroutine[Any, Any, T4]],
+            Optional[Coroutine[Any, Any, T5]],
+        ]
+    ],
+) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5]]: ...
+
+
+async def gather_optional_coroutines(
+    *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]],
+) -> Tuple[Optional[T1], ...]:
+    """Helper function that allows waiting on multiple coroutines at once.
+
+    The return value is a tuple of the return values of the coroutines in order.
+    If a `None` is passed instead of a coroutine, it will be ignored and a None
+    is returned in the tuple.
+
+    Note: For typechecking we need to have an explicit overload for each
+    distinct number of coroutines passed in. If you see type problems, it's
+    likely because you're using many arguments and you need to add a new
+    overload above.
+    """
+    try:
+        results = await make_deferred_yieldable(
+            defer.gatherResults(
+                [
+                    run_coroutine_in_background(coroutine)
+                    for coroutine in coroutines
+                    if coroutine is not None
+                ],
+                consumeErrors=True,
+            )
+        )
+
+        results_iter = iter(results)
+        return tuple(
+            next(results_iter) if coroutine is not None else None
+            for coroutine in coroutines
+        )
+    except defer.FirstError as dfe:
+        # unwrap the error from defer.gatherResults.
+
+        # The raised exception's traceback only includes func() etc if
+        # the 'await' happens before the exception is thrown - ie if the failure
+        # happens *asynchronously* - otherwise Twisted throws away the traceback as it
+        # could be large.
+        #
+        # We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
+        # we could throw Twisted into the fires of Mordor.
+
+        # suppress exception chaining, because the FirstError doesn't tell us anything
+        # very interesting.
+        assert isinstance(dfe.subFailure.value, BaseException)
+        raise dfe.subFailure.value from None
+
+
 @attr.s(slots=True, auto_attribs=True)
 class _LinearizerEntry:
     # The number of things executing.
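
As a quick illustration of the semantics documented in the docstring above (a hypothetical sketch; `get_number`, `boom`, and `demo` are made up): `None` arguments keep their position in the result tuple, and a failure in any coroutine is re-raised as the original exception rather than as a `defer.FirstError`.

# A minimal sketch, assuming Synapse is importable; the coroutine functions
# here are invented purely for illustration, not code from this commit.
from synapse.util.async_helpers import gather_optional_coroutines


async def get_number() -> int:
    return 7


async def boom() -> int:
    raise ValueError("nope")


async def demo() -> None:
    # The result tuple lines up with the arguments, None slots included.
    a, b, c = await gather_optional_coroutines(get_number(), None, get_number())
    assert (a, b, c) == (7, None, 7)

    # A failing coroutine surfaces its original exception, not a FirstError.
    try:
        await gather_optional_coroutines(get_number(), boom())
    except ValueError:
        pass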


@@ -18,7 +18,7 @@
 #
 #
 import traceback
-from typing import Generator, List, NoReturn, Optional
+from typing import Any, Coroutine, Generator, List, NoReturn, Optional, Tuple, TypeVar
 
 from parameterized import parameterized_class
 
@@ -39,6 +39,7 @@ from synapse.util.async_helpers import (
     ObservableDeferred,
     concurrently_execute,
     delay_cancellation,
+    gather_optional_coroutines,
     stop_cancellation,
     timeout_deferred,
 )
@@ -46,6 +47,8 @@ from synapse.util.async_helpers import (
 from tests.server import get_clock
 from tests.unittest import TestCase
 
+T = TypeVar("T")
+
 
 class ObservableDeferredTest(TestCase):
     def test_succeed(self) -> None:
@@ -588,3 +591,106 @@ class AwakenableSleeperTests(TestCase):
         sleeper.wake("name")
         self.assertTrue(d1.called)
         self.assertTrue(d2.called)
+
+
+class GatherCoroutineTests(TestCase):
+    """Tests for `gather_optional_coroutines`"""
+
+    def make_coroutine(self) -> Tuple[Coroutine[Any, Any, T], "defer.Deferred[T]"]:
+        """Returns a coroutine and a deferred that it is waiting on to resolve"""
+
+        d: "defer.Deferred[T]" = defer.Deferred()
+
+        async def inner() -> T:
+            with PreserveLoggingContext():
+                return await d
+
+        return inner(), d
+
+    def test_single(self) -> None:
+        "Test passing in a single coroutine works"
+
+        with LoggingContext("test_ctx") as text_ctx:
+            deferred: "defer.Deferred[None]"
+            coroutine, deferred = self.make_coroutine()
+
+            gather_deferred = defer.ensureDeferred(
+                gather_optional_coroutines(coroutine)
+            )
+
+            # We shouldn't have a result yet, and should be in the sentinel
+            # context.
+            self.assertNoResult(gather_deferred)
+            self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+            # Resolving the deferred will resolve the coroutine
+            deferred.callback(None)
+
+            # All coroutines have resolved, and so we should have the results
+            result = self.successResultOf(gather_deferred)
+            self.assertEqual(result, (None,))
+
+            # We should be back in the normal context.
+            self.assertEqual(current_context(), text_ctx)
+
+    def test_multiple_resolve(self) -> None:
+        "Test passing in multiple coroutine that all resolve works"
+
+        with LoggingContext("test_ctx") as test_ctx:
+            deferred1: "defer.Deferred[int]"
+            coroutine1, deferred1 = self.make_coroutine()
+            deferred2: "defer.Deferred[str]"
+            coroutine2, deferred2 = self.make_coroutine()
+
+            gather_deferred = defer.ensureDeferred(
+                gather_optional_coroutines(coroutine1, coroutine2)
+            )
+
+            # We shouldn't have a result yet, and should be in the sentinel
+            # context.
+            self.assertNoResult(gather_deferred)
+            self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+            # Even if we resolve one of the coroutines, we shouldn't have a result
+            # yet
+            deferred2.callback("test")
+            self.assertNoResult(gather_deferred)
+            self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+            deferred1.callback(1)
+
+            # All coroutines have resolved, and so we should have the results
+            result = self.successResultOf(gather_deferred)
+            self.assertEqual(result, (1, "test"))
+
+            # We should be back in the normal context.
+            self.assertEqual(current_context(), test_ctx)
+
+    def test_multiple_fail(self) -> None:
+        "Test passing in multiple coroutine where one fails does the right thing"
+
+        with LoggingContext("test_ctx") as test_ctx:
+            deferred1: "defer.Deferred[int]"
+            coroutine1, deferred1 = self.make_coroutine()
+            deferred2: "defer.Deferred[str]"
+            coroutine2, deferred2 = self.make_coroutine()
+
+            gather_deferred = defer.ensureDeferred(
+                gather_optional_coroutines(coroutine1, coroutine2)
+            )
+
+            # We shouldn't have a result yet, and should be in the sentinel
+            # context.
+            self.assertNoResult(gather_deferred)
+            self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+            # Throw an exception in one of the coroutines
+            exc = Exception("test")
+            deferred2.errback(exc)
+
+            # Expect the gather deferred to immediately fail
+            result_exc = self.failureResultOf(gather_deferred)
+            self.assertEqual(result_exc.value, exc)
+
+            # We should be back in the normal context.
+            self.assertEqual(current_context(), test_ctx)