Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

Patrick Cloke, 2023-11-08 07:45:34 -05:00
commit b77c9c3f73
19 changed files with 878 additions and 143 deletions

changelog.d/16532.misc (new file)

@@ -0,0 +1 @@
+Support reactor tick timings on more types of event loops.

changelog.d/16583.misc (new file)

@@ -0,0 +1 @@
+Avoid executing no-op queries.

changelog.d/16590.misc (new file)

@@ -0,0 +1 @@
+Run push rule evaluator setup in parallel.

changelog.d/16596.misc (new file)

@@ -0,0 +1 @@
+Improve tests of the SQL generator.

changelog.d/16605.misc (new file)

@@ -0,0 +1 @@
+Bump setuptools-rust from 1.8.0 to 1.8.1.

changelog.d/16609.bugfix (new file)

@@ -0,0 +1 @@
+Fix a long-standing bug where some queries updated the same row twice. Introduced in Synapse 1.57.0.

mypy.ini

@@ -37,8 +37,8 @@ files =
     build_rust.py

 [mypy-synapse.metrics._reactor_metrics]
-# This module imports select.epoll. That exists on Linux, but doesn't on macOS.
-# See https://github.com/matrix-org/synapse/pull/11771.
+# This module pokes at the internals of OS-specific classes, to appease mypy
+# on different systems we add additional ignores.
 warn_unused_ignores = False

 [mypy-synapse.util.caches.treecache]

poetry.lock (generated)

@@ -2439,28 +2439,28 @@ files = [

 [[package]]
 name = "ruff"
-version = "0.0.292"
-description = "An extremely fast Python linter, written in Rust."
+version = "0.1.4"
+description = "An extremely fast Python linter and code formatter, written in Rust."
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "ruff-0.0.292-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:02f29db018c9d474270c704e6c6b13b18ed0ecac82761e4fcf0faa3728430c96"},
-    {file = "ruff-0.0.292-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:69654e564342f507edfa09ee6897883ca76e331d4bbc3676d8a8403838e9fade"},
-    {file = "ruff-0.0.292-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c3c91859a9b845c33778f11902e7b26440d64b9d5110edd4e4fa1726c41e0a4"},
-    {file = "ruff-0.0.292-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f4476f1243af2d8c29da5f235c13dca52177117935e1f9393f9d90f9833f69e4"},
-    {file = "ruff-0.0.292-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:be8eb50eaf8648070b8e58ece8e69c9322d34afe367eec4210fdee9a555e4ca7"},
-    {file = "ruff-0.0.292-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:9889bac18a0c07018aac75ef6c1e6511d8411724d67cb879103b01758e110a81"},
-    {file = "ruff-0.0.292-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6bdfabd4334684a4418b99b3118793f2c13bb67bf1540a769d7816410402a205"},
-    {file = "ruff-0.0.292-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aa7c77c53bfcd75dbcd4d1f42d6cabf2485d2e1ee0678da850f08e1ab13081a8"},
-    {file = "ruff-0.0.292-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e087b24d0d849c5c81516ec740bf4fd48bf363cfb104545464e0fca749b6af9"},
-    {file = "ruff-0.0.292-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:f160b5ec26be32362d0774964e218f3fcf0a7da299f7e220ef45ae9e3e67101a"},
-    {file = "ruff-0.0.292-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ac153eee6dd4444501c4bb92bff866491d4bfb01ce26dd2fff7ca472c8df9ad0"},
-    {file = "ruff-0.0.292-py3-none-musllinux_1_2_i686.whl", hash = "sha256:87616771e72820800b8faea82edd858324b29bb99a920d6aa3d3949dd3f88fb0"},
-    {file = "ruff-0.0.292-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b76deb3bdbea2ef97db286cf953488745dd6424c122d275f05836c53f62d4016"},
-    {file = "ruff-0.0.292-py3-none-win32.whl", hash = "sha256:e854b05408f7a8033a027e4b1c7f9889563dd2aca545d13d06711e5c39c3d003"},
-    {file = "ruff-0.0.292-py3-none-win_amd64.whl", hash = "sha256:f27282bedfd04d4c3492e5c3398360c9d86a295be00eccc63914438b4ac8a83c"},
-    {file = "ruff-0.0.292-py3-none-win_arm64.whl", hash = "sha256:7f67a69c8f12fbc8daf6ae6d36705037bde315abf8b82b6e1f4c9e74eb750f68"},
-    {file = "ruff-0.0.292.tar.gz", hash = "sha256:1093449e37dd1e9b813798f6ad70932b57cf614e5c2b5c51005bf67d55db33ac"},
+    {file = "ruff-0.1.4-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:864958706b669cce31d629902175138ad8a069d99ca53514611521f532d91495"},
+    {file = "ruff-0.1.4-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:9fdd61883bb34317c788af87f4cd75dfee3a73f5ded714b77ba928e418d6e39e"},
+    {file = "ruff-0.1.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b4eaca8c9cc39aa7f0f0d7b8fe24ecb51232d1bb620fc4441a61161be4a17539"},
+    {file = "ruff-0.1.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a9a1301dc43cbf633fb603242bccd0aaa34834750a14a4c1817e2e5c8d60de17"},
+    {file = "ruff-0.1.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78e8db8ab6f100f02e28b3d713270c857d370b8d61871d5c7d1702ae411df683"},
+    {file = "ruff-0.1.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:80fea754eaae06335784b8ea053d6eb8e9aac75359ebddd6fee0858e87c8d510"},
+    {file = "ruff-0.1.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6bc02a480d4bfffd163a723698da15d1a9aec2fced4c06f2a753f87f4ce6969c"},
+    {file = "ruff-0.1.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9862811b403063765b03e716dac0fda8fdbe78b675cd947ed5873506448acea4"},
+    {file = "ruff-0.1.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58826efb8b3efbb59bb306f4b19640b7e366967a31c049d49311d9eb3a4c60cb"},
+    {file = "ruff-0.1.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:fdfd453fc91d9d86d6aaa33b1bafa69d114cf7421057868f0b79104079d3e66e"},
+    {file = "ruff-0.1.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:e8791482d508bd0b36c76481ad3117987301b86072158bdb69d796503e1c84a8"},
+    {file = "ruff-0.1.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:01206e361021426e3c1b7fba06ddcb20dbc5037d64f6841e5f2b21084dc51800"},
+    {file = "ruff-0.1.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:645591a613a42cb7e5c2b667cbefd3877b21e0252b59272ba7212c3d35a5819f"},
+    {file = "ruff-0.1.4-py3-none-win32.whl", hash = "sha256:99908ca2b3b85bffe7e1414275d004917d1e0dfc99d497ccd2ecd19ad115fd0d"},
+    {file = "ruff-0.1.4-py3-none-win_amd64.whl", hash = "sha256:1dfd6bf8f6ad0a4ac99333f437e0ec168989adc5d837ecd38ddb2cc4a2e3db8a"},
+    {file = "ruff-0.1.4-py3-none-win_arm64.whl", hash = "sha256:d98ae9ebf56444e18a3e3652b3383204748f73e247dea6caaf8b52d37e6b32da"},
+    {file = "ruff-0.1.4.tar.gz", hash = "sha256:21520ecca4cc555162068d87c747b8f95e1e95f8ecfcbbe59e8dd00710586315"},
 ]

 [[package]]

@@ -2580,13 +2580,13 @@ testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (

 [[package]]
 name = "setuptools-rust"
-version = "1.8.0"
+version = "1.8.1"
 description = "Setuptools Rust extension plugin"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "setuptools-rust-1.8.0.tar.gz", hash = "sha256:5e02b7a80058853bf64127314f6b97d0efed11e08b94c88ca639a20976f6adc4"},
-    {file = "setuptools_rust-1.8.0-py3-none-any.whl", hash = "sha256:95ec67edee2ca73233c9e75250e9d23a302aa23b4c8413dfd19c14c30d08f703"},
+    {file = "setuptools-rust-1.8.1.tar.gz", hash = "sha256:94b1dd5d5308b3138d5b933c3a2b55e6d6927d1a22632e509fcea9ddd0f7e486"},
+    {file = "setuptools_rust-1.8.1-py3-none-any.whl", hash = "sha256:b5324493949ccd6aa0c03890c5f6b5f02de4512e3ac1697d02e9a6c02b18aa8e"},
 ]

 [package.dependencies]

@@ -3069,13 +3069,13 @@ files = [

 [[package]]
 name = "types-jsonschema"
-version = "4.19.0.3"
+version = "4.19.0.4"
 description = "Typing stubs for jsonschema"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "types-jsonschema-4.19.0.3.tar.gz", hash = "sha256:e0fc0f5d51fd0988bf193be42174a5376b0096820ff79505d9c1b66de23f0581"},
-    {file = "types_jsonschema-4.19.0.3-py3-none-any.whl", hash = "sha256:5cedbb661e5ca88d95b94b79902423e3f97a389c245e5fe0ab384122f27d56b9"},
+    {file = "types-jsonschema-4.19.0.4.tar.gz", hash = "sha256:994feb6632818259c4b5dbd733867824cb475029a6abc2c2b5201a2268b6e7d2"},
+    {file = "types_jsonschema-4.19.0.4-py3-none-any.whl", hash = "sha256:b73c3f4ba3cd8108602d1198a438e2698d5eb6b9db206ed89a33e24729b0abe7"},
 ]

 [package.dependencies]

@@ -3141,13 +3141,13 @@ cryptography = ">=35.0.0"

 [[package]]
 name = "types-pyyaml"
-version = "6.0.12.11"
+version = "6.0.12.12"
 description = "Typing stubs for PyYAML"
 optional = false
 python-versions = "*"
 files = [
-    {file = "types-PyYAML-6.0.12.11.tar.gz", hash = "sha256:7d340b19ca28cddfdba438ee638cd4084bde213e501a3978738543e27094775b"},
-    {file = "types_PyYAML-6.0.12.11-py3-none-any.whl", hash = "sha256:a461508f3096d1d5810ec5ab95d7eeecb651f3a15b71959999988942063bf01d"},
+    {file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"},
+    {file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"},
 ]

 [[package]]

@@ -3447,4 +3447,4 @@ user-search = ["pyicu"]

 [metadata]
 lock-version = "2.0"
 python-versions = "^3.8.0"
-content-hash = "a08543c65f18cc7e9dea648e89c18ab88fc1747aa2e029aa208f777fc3db06dd"
+content-hash = "369455d6a67753a6bcfbad3cd86801b1dd02896d0180080e2ba9501e007353ec"

pyproject.toml

@@ -321,7 +321,7 @@ all = [
 # This helps prevents merge conflicts when running a batch of dependabot updates.
 isort = ">=5.10.1"
 black = ">=22.7.0"
-ruff = "0.0.292"
+ruff = "0.1.4"

 # Type checking only works with the pydantic.v1 compat module from pydantic v2
 pydantic = "^2"

@@ -381,7 +381,7 @@ furo = ">=2022.12.7,<2024.0.0"
 # system changes.
 # We are happy to raise these upper bounds upon request,
 # provided we check that it's safe to do so (i.e. that CI passes).
-requires = ["poetry-core>=1.1.0,<=1.7.0", "setuptools_rust>=1.3,<=1.8.0"]
+requires = ["poetry-core>=1.1.0,<=1.7.0", "setuptools_rust>=1.3,<=1.8.1"]
 build-backend = "poetry.core.masonry.api"

synapse/metrics/_reactor_metrics.py

@@ -12,17 +12,45 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import select
+import logging
 import time
-from typing import Any, Iterable, List, Tuple
+from selectors import SelectSelector, _PollLikeSelector  # type: ignore[attr-defined]
+from typing import Any, Callable, Iterable

 from prometheus_client import Histogram, Metric
 from prometheus_client.core import REGISTRY, GaugeMetricFamily

-from twisted.internet import reactor
+from twisted.internet import reactor, selectreactor
+from twisted.internet.asyncioreactor import AsyncioSelectorReactor

 from synapse.metrics._types import Collector

+try:
+    from selectors import KqueueSelector
+except ImportError:
+
+    class KqueueSelector:  # type: ignore[no-redef]
+        pass
+
+
+try:
+    from twisted.internet.epollreactor import EPollReactor
+except ImportError:
+
+    class EPollReactor:  # type: ignore[no-redef]
+        pass
+
+
+try:
+    from twisted.internet.pollreactor import PollReactor
+except ImportError:
+
+    class PollReactor:  # type: ignore[no-redef]
+        pass
+
+
+logger = logging.getLogger(__name__)
+
 #
 # Twisted reactor metrics
 #

@@ -34,52 +62,100 @@ tick_time = Histogram(
 )


-class EpollWrapper:
-    """a wrapper for an epoll object which records the time between polls"""
+class CallWrapper:
+    """A wrapper for a callable which records the time between calls"""

-    def __init__(self, poller: "select.epoll"):  # type: ignore[name-defined]
+    def __init__(self, wrapped: Callable[..., Any]):
         self.last_polled = time.time()
-        self._poller = poller
+        self._wrapped = wrapped

-    def poll(self, *args, **kwargs) -> List[Tuple[int, int]]:  # type: ignore[no-untyped-def]
-        # record the time since poll() was last called. This gives a good proxy for
+    def __call__(self, *args, **kwargs) -> Any:  # type: ignore[no-untyped-def]
+        # record the time since this was last called. This gives a good proxy for
         # how long it takes to run everything in the reactor - ie, how long anything
         # waiting for the next tick will have to wait.
         tick_time.observe(time.time() - self.last_polled)

-        ret = self._poller.poll(*args, **kwargs)
+        ret = self._wrapped(*args, **kwargs)

         self.last_polled = time.time()
         return ret


+class ObjWrapper:
+    """A wrapper for an object which wraps a specified method in CallWrapper.
+
+    Other methods/attributes are passed to the original object.
+
+    This is necessary when the wrapped object does not allow the attribute to be
+    overwritten.
+    """
+
+    def __init__(self, wrapped: Any, method_name: str):
+        self._wrapped = wrapped
+        self._method_name = method_name
+        self._wrapped_method = CallWrapper(getattr(wrapped, method_name))
+
     def __getattr__(self, item: str) -> Any:
-        return getattr(self._poller, item)
+        if item == self._method_name:
+            return self._wrapped_method
+
+        return getattr(self._wrapped, item)


 class ReactorLastSeenMetric(Collector):
-    def __init__(self, epoll_wrapper: EpollWrapper):
-        self._epoll_wrapper = epoll_wrapper
+    def __init__(self, call_wrapper: CallWrapper):
+        self._call_wrapper = call_wrapper

     def collect(self) -> Iterable[Metric]:
         cm = GaugeMetricFamily(
             "python_twisted_reactor_last_seen",
             "Seconds since the Twisted reactor was last seen",
         )
-        cm.add_metric([], time.time() - self._epoll_wrapper.last_polled)
+        cm.add_metric([], time.time() - self._call_wrapper.last_polled)
         yield cm


+# Twisted has already select a reasonable reactor for us, so assumptions can be
+# made about the shape.
+wrapper = None
 try:
-    # if the reactor has a `_poller` attribute, which is an `epoll` object
-    # (ie, it's an EPollReactor), we wrap the `epoll` with a thing that will
-    # measure the time between ticks
-    from select import epoll  # type: ignore[attr-defined]
-
-    poller = reactor._poller  # type: ignore[attr-defined]
-except (AttributeError, ImportError):
-    pass
-else:
-    if isinstance(poller, epoll):
-        poller = EpollWrapper(poller)
-        reactor._poller = poller  # type: ignore[attr-defined]
-        REGISTRY.register(ReactorLastSeenMetric(poller))
+    if isinstance(reactor, (PollReactor, EPollReactor)):
+        reactor._poller = ObjWrapper(reactor._poller, "poll")  # type: ignore[attr-defined]
+        wrapper = reactor._poller._wrapped_method  # type: ignore[attr-defined]
+
+    elif isinstance(reactor, selectreactor.SelectReactor):
+        # Twisted uses a module-level _select function.
+        wrapper = selectreactor._select = CallWrapper(selectreactor._select)
+
+    elif isinstance(reactor, AsyncioSelectorReactor):
+        # For asyncio look at the underlying asyncio event loop.
+        asyncio_loop = reactor._asyncioEventloop  # A sub-class of BaseEventLoop.
+
+        # A sub-class of BaseSelector.
+        selector = asyncio_loop._selector  # type: ignore[attr-defined]
+
+        if isinstance(selector, SelectSelector):
+            wrapper = selector._select = CallWrapper(selector._select)  # type: ignore[attr-defined]
+
+        # poll, epoll, and /dev/poll.
+        elif isinstance(selector, _PollLikeSelector):
+            selector._selector = ObjWrapper(selector._selector, "poll")  # type: ignore[attr-defined]
+            wrapper = selector._selector._wrapped_method  # type: ignore[attr-defined]
+
+        elif isinstance(selector, KqueueSelector):
+            selector._selector = ObjWrapper(selector._selector, "control")  # type: ignore[attr-defined]
+            wrapper = selector._selector._wrapped_method  # type: ignore[attr-defined]
+
+        else:
+            # E.g. this does not support the (Windows-only) ProactorEventLoop.
+            logger.warning(
+                "Skipping configuring ReactorLastSeenMetric: unexpected asyncio loop selector: %r via %r",
+                selector,
+                asyncio_loop,
+            )
+except Exception as e:
+    logger.warning("Configuring ReactorLastSeenMetric failed: %r", e)
+
+
+if wrapper:
+    REGISTRY.register(ReactorLastSeenMetric(wrapper))
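A note on the trick itself: anything the event loop calls exactly once per iteration can be swapped for a wrapper that records the gap between calls. Below is a minimal, standalone sketch of the idea on a bare asyncio loop (hypothetical names, not part of this patch). It reaches for the same private _selector attribute the patch uses, so it only works on selector-based loops; the real code additionally needs ObjWrapper because C-level poll objects such as select.epoll refuse attribute assignment, whereas the Python-level selector object patched here accepts it.

    import asyncio
    import time


    class TimedCall:
        """Record the time between successive calls to the wrapped callable."""

        def __init__(self, wrapped):
            self.last_polled = time.time()
            self.gaps = []  # one sample per event-loop tick
            self._wrapped = wrapped

        def __call__(self, *args, **kwargs):
            self.gaps.append(time.time() - self.last_polled)
            ret = self._wrapped(*args, **kwargs)
            self.last_polled = time.time()
            return ret


    async def main() -> None:
        loop = asyncio.get_running_loop()
        # Private attribute; absent on Windows' ProactorEventLoop.
        selector = loop._selector  # type: ignore[attr-defined]
        selector.select = TimedCall(selector.select)
        await asyncio.sleep(0.1)
        print(f"{len(selector.select.gaps)} polls observed")


    asyncio.run(main())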

synapse/push/bulk_push_rule_evaluator.py

@@ -25,10 +25,13 @@ from typing import (
     Sequence,
     Tuple,
     Union,
+    cast,
 )

 from prometheus_client import Counter

+from twisted.internet.defer import Deferred
+
 from synapse.api.constants import (
     MAIN_TIMELINE,
     EventContentFields,

@@ -40,11 +43,15 @@ from synapse.api.room_versions import PushRuleRoomFlag
 from synapse.event_auth import auth_types_for_event, get_user_power_level
 from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.state import POWER_KEY
 from synapse.storage.databases.main.roommember import EventIdMembership
+from synapse.storage.roommember import ProfileInfo
 from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator
 from synapse.types import JsonValue
 from synapse.types.state import StateFilter
+from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import gather_results
 from synapse.util.caches import register_cache
 from synapse.util.metrics import measure_func
 from synapse.visibility import filter_event_for_clients_with_state

@@ -342,15 +349,41 @@ class BulkPushRuleEvaluator:
         rules_by_user = await self._get_rules_for_event(event)
         actions_by_user: Dict[str, Collection[Union[Mapping, str]]] = {}

-        room_member_count = await self.store.get_number_joined_users_in_room(
-            event.room_id
-        )
-
+        # Gather a bunch of info in parallel.
+        #
+        # This has a lot of ignored types and casting due to the use of @cached
+        # decorated functions passed into run_in_background.
+        #
+        # See https://github.com/matrix-org/synapse/issues/16606
         (
-            power_levels,
-            sender_power_level,
-        ) = await self._get_power_levels_and_sender_level(
-            event, context, event_id_to_event
+            room_member_count,
+            (power_levels, sender_power_level),
+            related_events,
+            profiles,
+        ) = await make_deferred_yieldable(
+            cast(
+                "Deferred[Tuple[int, Tuple[dict, Optional[int]], Dict[str, Dict[str, JsonValue]], Mapping[str, ProfileInfo]]]",
+                gather_results(
+                    (
+                        run_in_background(  # type: ignore[call-arg]
+                            self.store.get_number_joined_users_in_room, event.room_id  # type: ignore[arg-type]
+                        ),
+                        run_in_background(
+                            self._get_power_levels_and_sender_level,
+                            event,
+                            context,
+                            event_id_to_event,
+                        ),
+                        run_in_background(self._related_events, event),
+                        run_in_background(  # type: ignore[call-arg]
+                            self.store.get_subset_users_in_room_with_profiles,
+                            event.room_id,  # type: ignore[arg-type]
+                            rules_by_user.keys(),  # type: ignore[arg-type]
+                        ),
+                    ),
+                    consumeErrors=True,
+                ).addErrback(unwrapFirstError),
+            )
         )

         # Find the event's thread ID.

@@ -366,8 +399,6 @@
             # the parent is part of a thread.
             thread_id = await self.store.get_thread_id(relation.parent_id)

-        related_events = await self._related_events(event)
-
         # It's possible that old room versions have non-integer power levels (floats or
         # strings; even the occasional `null`). For old rooms, we interpret these as if
         # they were integers. Do this here for the `@room` power level threshold.

@@ -400,11 +431,6 @@
             self.hs.config.experimental.msc1767_enabled,  # MSC3931 flag
         )

-        users = rules_by_user.keys()
-        profiles = await self.store.get_subset_users_in_room_with_profiles(
-            event.room_id, users
-        )
-
         for uid, rules in rules_by_user.items():
             if event.sender == uid:
                 continue
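The shape of the change: the four independent lookups are started with run_in_background and collected with gather_results, so the setup cost becomes the slowest lookup rather than the sum of all four, and unwrapFirstError converts Twisted's FirstError wrapper back into the original exception. A minimal sketch of the same fan-out using stock Twisted primitives (illustrative delays and stand-in values only; Synapse's helpers add logcontext handling and typing on top of gatherResults):

    from twisted.internet import defer, task


    @defer.inlineCallbacks
    def main(reactor):
        # Start every lookup before waiting on any of them.
        d1 = task.deferLater(reactor, 0.01, lambda: 42)         # stand-in: member count
        d2 = task.deferLater(reactor, 0.02, lambda: (100, 50))  # stand-in: power levels
        d3 = task.deferLater(reactor, 0.01, lambda: [])         # stand-in: related events
        results = yield defer.gatherResults([d1, d2, d3], consumeErrors=True)
        print(results)  # [42, (100, 50), []]


    task.react(main)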

synapse/storage/database.py

@@ -1117,7 +1117,7 @@ class DatabasePool:
         txn: LoggingTransaction,
         table: str,
         keys: Collection[str],
-        values: Iterable[Iterable[Any]],
+        values: Collection[Iterable[Any]],
     ) -> None:
         """Executes an INSERT query on the named table.

@@ -1130,6 +1130,9 @@
             keys: list of column names
             values: for each row, a list of values in the same order as `keys`
         """
+        # If there's nothing to insert, then skip executing the query.
+        if not values:
+            return

         if isinstance(txn.database_engine, PostgresEngine):
             # We use `execute_values` as it can be a lot faster than `execute_batch`,

@@ -1401,12 +1404,12 @@
         allvalues.update(values)
         latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)

-        sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % (
+        sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %sDO %s" % (
             table,
             ", ".join(k for k in allvalues),
             ", ".join("?" for _ in allvalues),
             ", ".join(k for k in keyvalues),
-            f"WHERE {where_clause}" if where_clause else "",
+            f"WHERE {where_clause} " if where_clause else "",
             latter,
         )
         txn.execute(sql, list(allvalues.values()))
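The `%sDO` change is easy to misread: the old format string hard-coded a space before DO and the optional WHERE fragment had no trailing space, so upserts without a where_clause rendered with a double space ("ON CONFLICT (k)  DO ..."). Re-creating the string assembly in isolation (hypothetical helper, same format string as above) shows the corrected rendering both ways:

    def build_upsert_sql(table, keys, values, where_clause=None):
        allnames = list(keys) + list(values)
        latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
        return "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %sDO %s" % (
            table,
            ", ".join(allnames),
            ", ".join("?" for _ in allnames),
            ", ".join(keys),
            f"WHERE {where_clause} " if where_clause else "",
            latter,
        )

    print(build_upsert_sql("t", ["k"], ["v"]))
    # INSERT INTO t (k, v) VALUES (?, ?) ON CONFLICT (k) DO UPDATE SET v=EXCLUDED.v
    print(build_upsert_sql("t", ["k"], ["v"], "v IS NULL"))
    # INSERT INTO t (k, v) VALUES (?, ?) ON CONFLICT (k) WHERE v IS NULL DO UPDATE SET v=EXCLUDED.v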
@@ -1455,7 +1458,7 @@
         key_names: Collection[str],
         key_values: Collection[Iterable[Any]],
         value_names: Collection[str],
-        value_values: Iterable[Iterable[Any]],
+        value_values: Collection[Iterable[Any]],
     ) -> None:
         """
         Upsert, many times.

@@ -1468,6 +1471,19 @@
             value_values: A list of each row's value column values.
                 Ignored if value_names is empty.
         """
+        # If there's nothing to upsert, then skip executing the query.
+        if not key_values:
+            return
+
+        # No value columns, therefore make a blank list so that the following
+        # zip() works correctly.
+        if not value_names:
+            value_values = [() for x in range(len(key_values))]
+        elif len(value_values) != len(key_values):
+            raise ValueError(
+                f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number."
+            )
+
         if table not in self._unsafe_to_upsert_tables:
             return self.simple_upsert_many_txn_native_upsert(
                 txn, table, key_names, key_values, value_names, value_values

@@ -1502,10 +1518,6 @@
             value_values: A list of each row's value column values.
                 Ignored if value_names is empty.
         """
-        # No value columns, therefore make a blank list so that the following
-        # zip() works correctly.
-        if not value_names:
-            value_values = [() for x in range(len(key_values))]

         # Lock the table just once, to prevent it being done once per row.
         # Note that, according to Postgres' documentation, once obtained,

@@ -1543,10 +1555,7 @@
         allnames.extend(value_names)

         if not value_names:
-            # No value columns, therefore make a blank list so that the
-            # following zip() works correctly.
             latter = "NOTHING"
-            value_values = [() for x in range(len(key_values))]
         else:
             latter = "UPDATE SET " + ", ".join(
                 k + "=EXCLUDED." + k for k in value_names

@@ -1910,6 +1919,7 @@
         Returns:
             The results as a list of tuples.
         """
+        # If there's nothing to select, then skip executing the query.
         if not iterable:
             return []

@@ -2044,13 +2054,13 @@
             raise ValueError(
                 f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number."
             )

+        # If there is nothing to update, then skip executing the query.
+        if not key_values:
+            return
+
         # List of tuples of (value values, then key values)
         # (This matches the order needed for the query)
-        args = [tuple(x) + tuple(y) for x, y in zip(value_values, key_values)]
-
-        for ks, vs in zip(key_values, value_values):
-            args.append(tuple(vs) + tuple(ks))
+        args = [tuple(vv) + tuple(kv) for vv, kv in zip(value_values, key_values)]

         # 'col1 = ?, col2 = ?, ...'
         set_clause = ", ".join(f"{n} = ?" for n in value_names)
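This hunk is the bug from changelog.d/16609.bugfix: the list comprehension and the loop after it each appended one argument tuple per row, so execute_batch ran every UPDATE twice. Re-running both constructions on toy data makes the duplication visible:

    key_values = [("k1",), ("k2",)]
    value_values = [("v1",), ("v2",)]

    # Old: comprehension *and* loop, so each row appears twice.
    args = [tuple(x) + tuple(y) for x, y in zip(value_values, key_values)]
    for ks, vs in zip(key_values, value_values):
        args.append(tuple(vs) + tuple(ks))
    print(args)  # [('v1', 'k1'), ('v2', 'k2'), ('v1', 'k1'), ('v2', 'k2')]

    # New: one tuple per row.
    args = [tuple(vv) + tuple(kv) for vv, kv in zip(value_values, key_values)]
    print(args)  # [('v1', 'k1'), ('v2', 'k2')]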
@@ -2062,9 +2072,7 @@
             where_clause = ""

         # UPDATE mytable SET col1 = ?, col2 = ? WHERE col3 = ? AND col4 = ?
-        sql = f"""
-            UPDATE {table} SET {set_clause} {where_clause}
-        """
+        sql = f"UPDATE {table} SET {set_clause} {where_clause}"

         txn.execute_batch(sql, args)

@@ -2280,11 +2288,10 @@
         Returns:
             Number rows deleted
         """
+        # If there's nothing to delete, then skip executing the query.
         if not values:
             return 0

-        sql = "DELETE FROM %s" % table
-
         clause, values = make_in_list_sql_clause(txn.database_engine, column, values)
         clauses = [clause]

@@ -2292,8 +2299,7 @@
             clauses.append("%s = ?" % (key,))
             values.append(value)

-        if clauses:
-            sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
+        sql = "DELETE FROM %s WHERE %s" % (table, " AND ".join(clauses))
         txn.execute(sql, values)

         return txn.rowcount

synapse/storage/databases/main/devices.py

@@ -705,7 +705,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             key_names=("destination", "user_id"),
             key_values=[(destination, user_id) for user_id, _ in rows],
             value_names=("stream_id",),
-            value_values=((stream_id,) for _, stream_id in rows),
+            value_values=[(stream_id,) for _, stream_id in rows],
         )

         # Delete all sent outbound pokes

synapse/storage/databases/main/events.py

@@ -1476,7 +1476,7 @@ class PersistEventsStore:
             txn,
             table="event_json",
             keys=("event_id", "room_id", "internal_metadata", "json", "format_version"),
-            values=(
+            values=[
                 (
                     event.event_id,
                     event.room_id,

@@ -1485,7 +1485,7 @@
                     event.format_version,
                 )
                 for event, _ in events_and_contexts
-            ),
+            ],
         )

         self.db_pool.simple_insert_many_txn(

@@ -1508,7 +1508,7 @@
                 "state_key",
                 "rejection_reason",
             ),
-            values=(
+            values=[
                 (
                     self._instance_name,
                     event.internal_metadata.stream_ordering,

@@ -1527,7 +1527,7 @@
                     context.rejected,
                 )
                 for event, context in events_and_contexts
-            ),
+            ],
         )

         # If we're persisting an unredacted event we go and ensure

@@ -1550,11 +1550,11 @@
             txn,
             table="state_events",
             keys=("event_id", "room_id", "type", "state_key"),
-            values=(
+            values=[
                 (event.event_id, event.room_id, event.type, event.state_key)
                 for event, _ in events_and_contexts
                 if event.is_state()
-            ),
+            ],
         )

     def _store_rejected_events_txn(
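This file and the similar call-site changes below swap generator expressions for list comprehensions because simple_insert_many_txn now takes values: Collection[...] and begins with the `if not values: return` guard above — a guard a generator can never trip, since a generator is truthy even when it will yield nothing:

    print(bool(x for x in []))    # True: an empty generator is still truthy
    print(bool([x for x in []]))  # False: an empty list is falsy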

synapse/storage/databases/main/push_rule.py

@@ -28,8 +28,11 @@ from typing import (
     cast,
 )

+from twisted.internet import defer
+
 from synapse.api.errors import StoreError
 from synapse.config.homeserver import ExperimentalConfig
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (

@@ -51,7 +54,8 @@ from synapse.storage.util.id_generators import (
 )
 from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
 from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import json_encoder, unwrapFirstError
+from synapse.util.async_helpers import gather_results
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache

@@ -249,23 +253,33 @@ class PushRulesWorkerStore(
             user_id: [] for user_id in user_ids
         }

-        rows = cast(
-            List[Tuple[str, str, int, int, str, str]],
-            await self.db_pool.simple_select_many_batch(
-                table="push_rules",
-                column="user_name",
-                iterable=user_ids,
-                retcols=(
-                    "user_name",
-                    "rule_id",
-                    "priority_class",
-                    "priority",
-                    "conditions",
-                    "actions",
-                ),
-                desc="bulk_get_push_rules",
-                batch_size=1000,
-            ),
+        # gatherResults loses all type information.
+        rows, enabled_map_by_user = await make_deferred_yieldable(
+            gather_results(
+                (
+                    cast(
+                        "defer.Deferred[List[Tuple[str, str, int, int, str, str]]]",
+                        run_in_background(
+                            self.db_pool.simple_select_many_batch,
+                            table="push_rules",
+                            column="user_name",
+                            iterable=user_ids,
+                            retcols=(
+                                "user_name",
+                                "rule_id",
+                                "priority_class",
+                                "priority",
+                                "conditions",
+                                "actions",
+                            ),
+                            desc="bulk_get_push_rules",
+                            batch_size=1000,
+                        ),
+                    ),
+                    run_in_background(self.bulk_get_push_rules_enabled, user_ids),
+                ),
+                consumeErrors=True,
+            ).addErrback(unwrapFirstError)
         )

         # Sort by highest priority_class, then highest priority.

@@ -276,8 +290,6 @@
                 (rule_id, priority_class, conditions, actions)
             )

-        enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
-
         results: Dict[str, FilteredPushRules] = {}

         for user_id, rules in raw_rules.items():

View file

@ -2268,7 +2268,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
txn, txn,
table="partial_state_rooms_servers", table="partial_state_rooms_servers",
keys=("room_id", "server_name"), keys=("room_id", "server_name"),
values=((room_id, s) for s in servers), values=[(room_id, s) for s in servers],
) )
self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(

synapse/storage/databases/main/search.py

@@ -106,7 +106,7 @@ class SearchWorkerStore(SQLBaseStore):
                 txn,
                 table="event_search",
                 keys=("event_id", "room_id", "key", "value"),
-                values=(
+                values=[
                     (
                         entry.event_id,
                         entry.room_id,

@@ -114,7 +114,7 @@
                         _clean_value_for_search(entry.value),
                     )
                     for entry in entries
-                ),
+                ],
             )
         else:

synapse/util/async_helpers.py

@@ -345,6 +345,7 @@ async def yieldable_gather_results_delaying_cancellation(
 T1 = TypeVar("T1")
 T2 = TypeVar("T2")
 T3 = TypeVar("T3")
+T4 = TypeVar("T4")


 @overload

@@ -380,6 +381,19 @@ def gather_results(
     ...


+@overload
+def gather_results(
+    deferredList: Tuple[
+        "defer.Deferred[T1]",
+        "defer.Deferred[T2]",
+        "defer.Deferred[T3]",
+        "defer.Deferred[T4]",
+    ],
+    consumeErrors: bool = ...,
+) -> "defer.Deferred[Tuple[T1, T2, T3, T4]]":
+    ...
+
+
 def gather_results(  # type: ignore[misc]
     deferredList: Tuple["defer.Deferred[T1]", ...],
     consumeErrors: bool = False,
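With the new overload, a four-element gather keeps per-element types, which is what lets bulk_push_rule_evaluator destructure its result tuple (modulo the @cached casts tracked in issue #16606). A sketch of the payoff at a call site, using trivially pre-fired deferreds:

    from twisted.internet import defer

    from synapse.util.async_helpers import gather_results

    d1: "defer.Deferred[int]" = defer.succeed(1)
    d2: "defer.Deferred[str]" = defer.succeed("a")
    d3: "defer.Deferred[bytes]" = defer.succeed(b"b")
    d4: "defer.Deferred[bool]" = defer.succeed(True)

    # mypy now infers Deferred[Tuple[int, str, bytes, bool]] here; before this
    # change only tuples of up to three deferreds had a precise overload.
    d = gather_results((d1, d2, d3, d4), consumeErrors=True)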

tests/storage/test_base.py

@@ -14,7 +14,7 @@
 from collections import OrderedDict
 from typing import Generator
-from unittest.mock import Mock
+from unittest.mock import Mock, call, patch

 from twisted.internet import defer

@@ -24,43 +24,90 @@ from synapse.storage.engines import create_engine
 from tests import unittest
 from tests.server import TestHomeServer
-from tests.utils import default_config
+from tests.utils import USE_POSTGRES_FOR_TESTS, default_config


 class SQLBaseStoreTestCase(unittest.TestCase):
     """Test the "simple" SQL generating methods in SQLBaseStore."""

     def setUp(self) -> None:
-        self.db_pool = Mock(spec=["runInteraction"])
+        # This is the Twisted connection pool.
+        conn_pool = Mock(spec=["runInteraction", "runWithConnection"])
         self.mock_txn = Mock()
-        self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"])
+        if USE_POSTGRES_FOR_TESTS:
+            # To avoid testing psycopg2 itself, patch execute_batch/execute_values
+            # to assert how it is called.
+            from psycopg2 import extras
+
+            self.mock_execute_batch = Mock()
+            self.execute_batch_patcher = patch.object(
+                extras, "execute_batch", new=self.mock_execute_batch
+            )
+            self.execute_batch_patcher.start()
+            self.mock_execute_values = Mock()
+            self.execute_values_patcher = patch.object(
+                extras, "execute_values", new=self.mock_execute_values
+            )
+            self.execute_values_patcher.start()
+
+            self.mock_conn = Mock(
+                spec_set=[
+                    "cursor",
+                    "rollback",
+                    "commit",
+                    "closed",
+                    "reconnect",
+                    "set_session",
+                    "encoding",
+                ]
+            )
+            self.mock_conn.encoding = "UNICODE"
+        else:
+            self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"])
         self.mock_conn.cursor.return_value = self.mock_txn
+        self.mock_txn.connection = self.mock_conn
         self.mock_conn.rollback.return_value = None
         # Our fake runInteraction just runs synchronously inline

         def runInteraction(func, *args, **kwargs) -> defer.Deferred:  # type: ignore[no-untyped-def]
             return defer.succeed(func(self.mock_txn, *args, **kwargs))

-        self.db_pool.runInteraction = runInteraction
+        conn_pool.runInteraction = runInteraction

         def runWithConnection(func, *args, **kwargs):  # type: ignore[no-untyped-def]
             return defer.succeed(func(self.mock_conn, *args, **kwargs))

-        self.db_pool.runWithConnection = runWithConnection
+        conn_pool.runWithConnection = runWithConnection

         config = default_config(name="test", parse=True)
         hs = TestHomeServer("test", config=config)

-        sqlite_config = {"name": "sqlite3"}
-        engine = create_engine(sqlite_config)
+        if USE_POSTGRES_FOR_TESTS:
+            db_config = {"name": "psycopg2", "args": {}}
+        else:
+            db_config = {"name": "sqlite3"}
+        engine = create_engine(db_config)
         fake_engine = Mock(wraps=engine)
         fake_engine.in_transaction.return_value = False
+        fake_engine.module.OperationalError = engine.module.OperationalError
+        fake_engine.module.DatabaseError = engine.module.DatabaseError
+        fake_engine.module.IntegrityError = engine.module.IntegrityError
+        # Don't convert param style to make assertions easier.
+        fake_engine.convert_param_style = lambda sql: sql
+        # To fix isinstance(...) checks.
+        fake_engine.__class__ = engine.__class__  # type: ignore[assignment]

-        db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
-        db._db_pool = self.db_pool
+        db = DatabasePool(Mock(), Mock(config=db_config), fake_engine)
+        db._db_pool = conn_pool

         self.datastore = SQLBaseStore(db, None, hs)  # type: ignore[arg-type]

+    def tearDown(self) -> None:
+        if USE_POSTGRES_FOR_TESTS:
+            self.execute_batch_patcher.stop()
+            self.execute_values_patcher.stop()
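(USE_POSTGRES_FOR_TESTS comes from tests/utils.py; in Synapse's test harness it is driven by the SYNAPSE_POSTGRES environment variable, e.g. SYNAPSE_POSTGRES=1 trial tests.storage.test_base, so the same test bodies assert against either psycopg2's batch helpers or sqlite's executemany.)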
     @defer.inlineCallbacks
     def test_insert_1col(self) -> Generator["defer.Deferred[object]", object, None]:
         self.mock_txn.rowcount = 1

@@ -71,7 +118,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             )
         )

-        self.mock_txn.execute.assert_called_with(
+        self.mock_txn.execute.assert_called_once_with(
             "INSERT INTO tablename (columname) VALUES(?)", ("Value",)
         )

@@ -87,10 +134,65 @@
             )
         )

-        self.mock_txn.execute.assert_called_with(
+        self.mock_txn.execute.assert_called_once_with(
             "INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)", (1, 2, 3)
         )
+    @defer.inlineCallbacks
+    def test_insert_many(self) -> Generator["defer.Deferred[object]", object, None]:
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_insert_many(
+                table="tablename",
+                keys=(
+                    "col1",
+                    "col2",
+                ),
+                values=[
+                    (
+                        "val1",
+                        "val2",
+                    ),
+                    ("val3", "val4"),
+                ],
+                desc="",
+            )
+        )
+
+        if USE_POSTGRES_FOR_TESTS:
+            self.mock_execute_values.assert_called_once_with(
+                self.mock_txn,
+                "INSERT INTO tablename (col1, col2) VALUES ?",
+                [("val1", "val2"), ("val3", "val4")],
+                template=None,
+                fetch=False,
+            )
+        else:
+            self.mock_txn.executemany.assert_called_once_with(
+                "INSERT INTO tablename (col1, col2) VALUES(?, ?)",
+                [("val1", "val2"), ("val3", "val4")],
+            )
+
+    @defer.inlineCallbacks
+    def test_insert_many_no_iterable(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_insert_many(
+                table="tablename",
+                keys=(
+                    "col1",
+                    "col2",
+                ),
+                values=[],
+                desc="",
+            )
+        )
+
+        if USE_POSTGRES_FOR_TESTS:
+            self.mock_execute_values.assert_not_called()
+        else:
+            self.mock_txn.executemany.assert_not_called()
     @defer.inlineCallbacks
     def test_select_one_1col(self) -> Generator["defer.Deferred[object]", object, None]:
         self.mock_txn.rowcount = 1

@@ -103,7 +205,7 @@
         )

         self.assertEqual("Value", value)
-        self.mock_txn.execute.assert_called_with(
+        self.mock_txn.execute.assert_called_once_with(
             "SELECT retcol FROM tablename WHERE keycol = ?", ["TheKey"]
         )

@@ -121,7 +223,7 @@
         )

         self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret)
-        self.mock_txn.execute.assert_called_with(
+        self.mock_txn.execute.assert_called_once_with(
             "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"]
         )

@@ -156,10 +258,58 @@
         )

         self.assertEqual([(1,), (2,), (3,)], ret)
-        self.mock_txn.execute.assert_called_with(
+        self.mock_txn.execute.assert_called_once_with(
             "SELECT colA FROM tablename WHERE keycol = ?", ["A set"]
         )
+    @defer.inlineCallbacks
+    def test_select_many_batch(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        self.mock_txn.rowcount = 3
+        self.mock_txn.fetchall.side_effect = [[(1,), (2,)], [(3,)]]
+
+        ret = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_select_many_batch(
+                table="tablename",
+                column="col1",
+                iterable=("val1", "val2", "val3"),
+                retcols=("col2",),
+                keyvalues={"col3": "val4"},
+                batch_size=2,
+            )
+        )
+
+        self.mock_txn.execute.assert_has_calls(
+            [
+                call(
+                    "SELECT col2 FROM tablename WHERE col1 = ANY(?) AND col3 = ?",
+                    [["val1", "val2"], "val4"],
+                ),
+                call(
+                    "SELECT col2 FROM tablename WHERE col1 = ANY(?) AND col3 = ?",
+                    [["val3"], "val4"],
+                ),
+            ],
+        )
+        self.assertEqual([(1,), (2,), (3,)], ret)
+
+    def test_select_many_no_iterable(self) -> None:
+        self.mock_txn.rowcount = 3
+        self.mock_txn.fetchall.side_effect = [(1,), (2,)]
+
+        ret = self.datastore.db_pool.simple_select_many_txn(
+            self.mock_txn,
+            table="tablename",
+            column="col1",
+            iterable=(),
+            retcols=("col2",),
+            keyvalues={"col3": "val4"},
+        )
+
+        self.mock_txn.execute.assert_not_called()
+        self.assertEqual([], ret)
     @defer.inlineCallbacks
     def test_update_one_1col(self) -> Generator["defer.Deferred[object]", object, None]:
         self.mock_txn.rowcount = 1

@@ -172,7 +322,7 @@
             )
         )

-        self.mock_txn.execute.assert_called_with(
+        self.mock_txn.execute.assert_called_once_with(
             "UPDATE tablename SET columnname = ? WHERE keycol = ?",
             ["New Value", "TheKey"],
         )

@@ -191,11 +341,69 @@
             )
         )

-        self.mock_txn.execute.assert_called_with(
+        self.mock_txn.execute.assert_called_once_with(
             "UPDATE tablename SET colC = ?, colD = ? WHERE" " colA = ? AND colB = ?",
             [3, 4, 1, 2],
         )
+    @defer.inlineCallbacks
+    def test_update_many(self) -> Generator["defer.Deferred[object]", object, None]:
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_update_many(
+                table="tablename",
+                key_names=("col1", "col2"),
+                key_values=[("val1", "val2")],
+                value_names=("col3",),
+                value_values=[("val3",)],
+                desc="",
+            )
+        )
+
+        if USE_POSTGRES_FOR_TESTS:
+            self.mock_execute_batch.assert_called_once_with(
+                self.mock_txn,
+                "UPDATE tablename SET col3 = ? WHERE col1 = ? AND col2 = ?",
+                [("val3", "val1", "val2")],
+            )
+        else:
+            self.mock_txn.executemany.assert_called_once_with(
+                "UPDATE tablename SET col3 = ? WHERE col1 = ? AND col2 = ?",
+                [("val3", "val1", "val2")],
+            )
+
+        # key_values and value_values must be the same length.
+        with self.assertRaises(ValueError):
+            yield defer.ensureDeferred(
+                self.datastore.db_pool.simple_update_many(
+                    table="tablename",
+                    key_names=("col1", "col2"),
+                    key_values=[("val1", "val2")],
+                    value_names=("col3",),
+                    value_values=[],
+                    desc="",
+                )
+            )
+
+    @defer.inlineCallbacks
+    def test_update_many_no_iterable(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_update_many(
+                table="tablename",
+                key_names=("col1", "col2"),
+                key_values=[],
+                value_names=("col3",),
+                value_values=[],
+                desc="",
+            )
+        )
+
+        if USE_POSTGRES_FOR_TESTS:
+            self.mock_execute_batch.assert_not_called()
+        else:
+            self.mock_txn.executemany.assert_not_called()
     @defer.inlineCallbacks
     def test_delete_one(self) -> Generator["defer.Deferred[object]", object, None]:
         self.mock_txn.rowcount = 1

@@ -206,6 +414,393 @@
             )
         )

-        self.mock_txn.execute.assert_called_with(
+        self.mock_txn.execute.assert_called_once_with(
             "DELETE FROM tablename WHERE keycol = ?", ["Go away"]
         )
+    @defer.inlineCallbacks
+    def test_delete_many(self) -> Generator["defer.Deferred[object]", object, None]:
+        self.mock_txn.rowcount = 2
+
+        result = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_delete_many(
+                table="tablename",
+                column="col1",
+                iterable=("val1", "val2"),
+                keyvalues={"col2": "val3"},
+                desc="",
+            )
+        )
+
+        self.mock_txn.execute.assert_called_once_with(
+            "DELETE FROM tablename WHERE col1 = ANY(?) AND col2 = ?",
+            [["val1", "val2"], "val3"],
+        )
+        self.assertEqual(result, 2)
+
+    @defer.inlineCallbacks
+    def test_delete_many_no_iterable(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        result = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_delete_many(
+                table="tablename",
+                column="col1",
+                iterable=(),
+                keyvalues={"col2": "val3"},
+                desc="",
+            )
+        )
+
+        self.mock_txn.execute.assert_not_called()
+        self.assertEqual(result, 0)
+
+    @defer.inlineCallbacks
+    def test_delete_many_no_keyvalues(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        self.mock_txn.rowcount = 2
+
+        result = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_delete_many(
+                table="tablename",
+                column="col1",
+                iterable=("val1", "val2"),
+                keyvalues={},
+                desc="",
+            )
+        )
+
+        self.mock_txn.execute.assert_called_once_with(
+            "DELETE FROM tablename WHERE col1 = ANY(?)", [["val1", "val2"]]
+        )
+        self.assertEqual(result, 2)
+
+    @defer.inlineCallbacks
+    def test_upsert(self) -> Generator["defer.Deferred[object]", object, None]:
+        self.mock_txn.rowcount = 1
+
+        result = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_upsert(
+                table="tablename",
+                keyvalues={"columnname": "oldvalue"},
+                values={"othercol": "newvalue"},
+            )
+        )
+
+        self.mock_txn.execute.assert_called_once_with(
+            "INSERT INTO tablename (columnname, othercol) VALUES (?, ?) ON CONFLICT (columnname) DO UPDATE SET othercol=EXCLUDED.othercol",
+            ["oldvalue", "newvalue"],
+        )
+        self.assertTrue(result)
+
+    @defer.inlineCallbacks
+    def test_upsert_no_values(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        self.mock_txn.rowcount = 1
+
+        result = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_upsert(
+                table="tablename",
+                keyvalues={"columnname": "value"},
+                values={},
+                insertion_values={"columnname": "value"},
+            )
+        )
+
+        self.mock_txn.execute.assert_called_once_with(
+            "INSERT INTO tablename (columnname) VALUES (?) ON CONFLICT (columnname) DO NOTHING",
+            ["value"],
+        )
+        self.assertTrue(result)
+
+    @defer.inlineCallbacks
+    def test_upsert_with_insertion(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        self.mock_txn.rowcount = 1
+
+        result = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_upsert(
+                table="tablename",
+                keyvalues={"columnname": "oldvalue"},
+                values={"othercol": "newvalue"},
+                insertion_values={"thirdcol": "insertionval"},
+            )
+        )
+
+        self.mock_txn.execute.assert_called_once_with(
+            "INSERT INTO tablename (columnname, thirdcol, othercol) VALUES (?, ?, ?) ON CONFLICT (columnname) DO UPDATE SET othercol=EXCLUDED.othercol",
+            ["oldvalue", "insertionval", "newvalue"],
+        )
+        self.assertTrue(result)
+
+    @defer.inlineCallbacks
+    def test_upsert_with_where(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        self.mock_txn.rowcount = 1
+
+        result = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_upsert(
+                table="tablename",
+                keyvalues={"columnname": "oldvalue"},
+                values={"othercol": "newvalue"},
+                where_clause="thirdcol IS NULL",
+            )
+        )
+
+        self.mock_txn.execute.assert_called_once_with(
+            "INSERT INTO tablename (columnname, othercol) VALUES (?, ?) ON CONFLICT (columnname) WHERE thirdcol IS NULL DO UPDATE SET othercol=EXCLUDED.othercol",
+            ["oldvalue", "newvalue"],
+        )
+        self.assertTrue(result)
+
+    @defer.inlineCallbacks
+    def test_upsert_many(self) -> Generator["defer.Deferred[object]", object, None]:
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_upsert_many(
+                table="tablename",
+                key_names=["keycol1", "keycol2"],
+                key_values=[["keyval1", "keyval2"], ["keyval3", "keyval4"]],
+                value_names=["valuecol3"],
+                value_values=[["val5"], ["val6"]],
+                desc="",
+            )
+        )
+
+        if USE_POSTGRES_FOR_TESTS:
+            self.mock_execute_values.assert_called_once_with(
+                self.mock_txn,
+                "INSERT INTO tablename (keycol1, keycol2, valuecol3) VALUES ? ON CONFLICT (keycol1, keycol2) DO UPDATE SET valuecol3=EXCLUDED.valuecol3",
+                [("keyval1", "keyval2", "val5"), ("keyval3", "keyval4", "val6")],
+                template=None,
+                fetch=False,
+            )
+        else:
+            self.mock_txn.executemany.assert_called_once_with(
+                "INSERT INTO tablename (keycol1, keycol2, valuecol3) VALUES (?, ?, ?) ON CONFLICT (keycol1, keycol2) DO UPDATE SET valuecol3=EXCLUDED.valuecol3",
+                [("keyval1", "keyval2", "val5"), ("keyval3", "keyval4", "val6")],
+            )
+
+    @defer.inlineCallbacks
+    def test_upsert_many_no_values(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_upsert_many(
+                table="tablename",
+                key_names=["columnname"],
+                key_values=[["oldvalue"]],
+                value_names=[],
+                value_values=[],
+                desc="",
+            )
+        )
+
+        if USE_POSTGRES_FOR_TESTS:
+            self.mock_execute_values.assert_called_once_with(
+                self.mock_txn,
+                "INSERT INTO tablename (columnname) VALUES ? ON CONFLICT (columnname) DO NOTHING",
+                [("oldvalue",)],
+                template=None,
+                fetch=False,
+            )
+        else:
+            self.mock_txn.executemany.assert_called_once_with(
+                "INSERT INTO tablename (columnname) VALUES (?) ON CONFLICT (columnname) DO NOTHING",
+                [("oldvalue",)],
+            )
+
+    @defer.inlineCallbacks
+    def test_upsert_emulated_no_values_exists(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename")
+
+        self.mock_txn.fetchall.return_value = [(1,)]
+
+        result = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_upsert(
+                table="tablename",
+                keyvalues={"columnname": "value"},
+                values={},
+                insertion_values={"columnname": "value"},
+            )
+        )
+
+        if USE_POSTGRES_FOR_TESTS:
+            self.mock_txn.execute.assert_has_calls(
+                [
+                    call("LOCK TABLE tablename in EXCLUSIVE MODE", ()),
+                    call("SELECT 1 FROM tablename WHERE columnname = ?", ["value"]),
+                ]
+            )
+        else:
+            self.mock_txn.execute.assert_called_once_with(
+                "SELECT 1 FROM tablename WHERE columnname = ?", ["value"]
+            )
+        self.assertFalse(result)
+
+    @defer.inlineCallbacks
+    def test_upsert_emulated_no_values_not_exists(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename")
+
+        self.mock_txn.fetchall.return_value = []
+        self.mock_txn.rowcount = 1
+
+        result = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_upsert(
+                table="tablename",
+                keyvalues={"columnname": "value"},
+                values={},
+                insertion_values={"columnname": "value"},
+            )
+        )
+
+        self.mock_txn.execute.assert_has_calls(
+            [
+                call(
+                    "SELECT 1 FROM tablename WHERE columnname = ?",
+                    ["value"],
+                ),
+                call("INSERT INTO tablename (columnname) VALUES (?)", ["value"]),
+            ],
+        )
+        self.assertTrue(result)
+
+    @defer.inlineCallbacks
+    def test_upsert_emulated_with_insertion_exists(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename")
+
+        self.mock_txn.rowcount = 1
+
+        result = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_upsert(
+                table="tablename",
+                keyvalues={"columnname": "oldvalue"},
+                values={"othercol": "newvalue"},
+                insertion_values={"thirdcol": "insertionval"},
+            )
+        )
+
+        if USE_POSTGRES_FOR_TESTS:
+            self.mock_txn.execute.assert_has_calls(
+                [
+                    call("LOCK TABLE tablename in EXCLUSIVE MODE", ()),
+                    call(
+                        "UPDATE tablename SET othercol = ? WHERE columnname = ?",
+                        ["newvalue", "oldvalue"],
+                    ),
+                ]
+            )
+        else:
+            self.mock_txn.execute.assert_called_once_with(
+                "UPDATE tablename SET othercol = ? WHERE columnname = ?",
+                ["newvalue", "oldvalue"],
+            )
+        self.assertTrue(result)
+
+    @defer.inlineCallbacks
+    def test_upsert_emulated_with_insertion_not_exists(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename")
+
+        self.mock_txn.rowcount = 0
+
+        result = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_upsert(
+                table="tablename",
+                keyvalues={"columnname": "oldvalue"},
+                values={"othercol": "newvalue"},
+                insertion_values={"thirdcol": "insertionval"},
+            )
+        )
+
+        self.mock_txn.execute.assert_has_calls(
+            [
+                call(
+                    "UPDATE tablename SET othercol = ? WHERE columnname = ?",
+                    ["newvalue", "oldvalue"],
+                ),
+                call(
+                    "INSERT INTO tablename (columnname, othercol, thirdcol) VALUES (?, ?, ?)",
+                    ["oldvalue", "newvalue", "insertionval"],
+                ),
+            ]
+        )
+        self.assertTrue(result)
+
+    @defer.inlineCallbacks
+    def test_upsert_emulated_with_where(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename")
+
+        self.mock_txn.rowcount = 1
+
+        result = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_upsert(
+                table="tablename",
+                keyvalues={"columnname": "oldvalue"},
+                values={"othercol": "newvalue"},
+                where_clause="thirdcol IS NULL",
+            )
+        )
+
+        if USE_POSTGRES_FOR_TESTS:
+            self.mock_txn.execute.assert_has_calls(
+                [
+                    call("LOCK TABLE tablename in EXCLUSIVE MODE", ()),
+                    call(
+                        "UPDATE tablename SET othercol = ? WHERE columnname = ? AND thirdcol IS NULL",
+                        ["newvalue", "oldvalue"],
+                    ),
+                ]
+            )
+        else:
+            self.mock_txn.execute.assert_called_once_with(
+                "UPDATE tablename SET othercol = ? WHERE columnname = ? AND thirdcol IS NULL",
+                ["newvalue", "oldvalue"],
+            )
+        self.assertTrue(result)
+
+    @defer.inlineCallbacks
+    def test_upsert_emulated_with_where_no_values(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        self.datastore.db_pool._unsafe_to_upsert_tables.add("tablename")
+
+        self.mock_txn.rowcount = 1
+
+        result = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_upsert(
+                table="tablename",
+                keyvalues={"columnname": "oldvalue"},
+                values={},
+                where_clause="thirdcol IS NULL",
+            )
+        )
+
+        if USE_POSTGRES_FOR_TESTS:
+            self.mock_txn.execute.assert_has_calls(
+                [
+                    call("LOCK TABLE tablename in EXCLUSIVE MODE", ()),
+                    call(
+                        "SELECT 1 FROM tablename WHERE columnname = ? AND thirdcol IS NULL",
+                        ["oldvalue"],
+                    ),
+                ]
+            )
+        else:
+            self.mock_txn.execute.assert_called_once_with(
+                "SELECT 1 FROM tablename WHERE columnname = ? AND thirdcol IS NULL",
+                ["oldvalue"],
+            )
+        self.assertFalse(result)