diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index cb4a5857be..17aacf6980 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -1472,6 +1472,11 @@ class DatabasePool:
         key_values: Collection[Collection[Any]],
         value_names: Collection[str],
         value_values: Collection[Collection[Any]],
+        # Given these are the same type as the normal values, force keyword-only so that
+        # they can't be confused.
+        *,
+        insertion_value_names: Collection[str] = (),
+        insertion_value_values: Collection[Iterable[Any]] = (),
         desc: str,
     ) -> None:
         """
@@ -1497,6 +1502,8 @@ class DatabasePool:
             key_values,
             value_names,
             value_values,
+            insertion_value_names=insertion_value_names,
+            insertion_value_values=insertion_value_values,
             db_autocommit=autocommit,
         )
 
@@ -1508,6 +1515,11 @@ class DatabasePool:
         key_values: Collection[Iterable[Any]],
         value_names: Collection[str],
         value_values: Collection[Iterable[Any]],
+        # Given these are the same type as the normal values, force keyword-only so that
+        # they can't be confused.
+        *,
+        insertion_value_names: Collection[str] = (),
+        insertion_value_values: Collection[Iterable[Any]] = (),
     ) -> None:
         """
         Upsert, many times.
@@ -1519,6 +1531,9 @@ class DatabasePool:
             value_names: The value column names
             value_values: A list of each row's value column values.
                 Ignored if value_names is empty.
+            insertion_value_names: The value column names to use only when inserting
+            insertion_value_values: A list of each row's value column values to use only
+                when inserting. Ignored if `insertion_value_names` is empty.
         """
         # If there's nothing to upsert, then skip executing the query.
         if not key_values:
@@ -1528,14 +1543,30 @@ class DatabasePool:
         # zip() works correctly.
         if not value_names:
             value_values = [() for x in range(len(key_values))]
-        elif len(value_values) != len(key_values):
+        elif len(key_values) != len(value_values):
             raise ValueError(
                 f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number."
             )
 
+        # No insertion value columns, therefore make a blank list so that the
+        # following zip() works correctly.
+        if not insertion_value_names:
+            insertion_value_values = [() for x in range(len(key_values))]
+        elif len(key_values) != len(insertion_value_values):
+            raise ValueError(
+                f"{len(key_values)} key rows and {len(insertion_value_values)} insertion value rows: should be the same number."
+            )
+
         if table not in self._unsafe_to_upsert_tables:
             return self.simple_upsert_many_txn_native_upsert(
-                txn, table, key_names, key_values, value_names, value_values
+                txn,
+                table,
+                key_names,
+                key_values,
+                value_names,
+                value_values,
+                insertion_value_names=insertion_value_names,
+                insertion_value_values=insertion_value_values,
             )
         else:
             return self.simple_upsert_many_txn_emulated(
@@ -1545,6 +1576,8 @@ class DatabasePool:
                 key_values,
                 value_names,
                 value_values,
+                insertion_value_names=insertion_value_names,
+                insertion_value_values=insertion_value_values,
             )
 
     def simple_upsert_many_txn_emulated(
@@ -1555,6 +1588,11 @@ class DatabasePool:
         key_values: Collection[Iterable[Any]],
         value_names: Collection[str],
         value_values: Iterable[Iterable[Any]],
+        # Given these are the same type as the normal values, force keyword-only so that
+        # they can't be confused.
+        *,
+        insertion_value_names: Collection[str] = (),
+        insertion_value_values: Collection[Iterable[Any]] = (),
     ) -> None:
         """
         Upsert, many times, but without native UPSERT support or batching.
@@ -1566,6 +1604,9 @@ class DatabasePool:
             value_names: The value column names
             value_values: A list of each row's value column values.
                 Ignored if value_names is empty.
+            insertion_value_names: The value column names to use only when inserting
+            insertion_value_values: A list of each row's value column values to use only
+                when inserting. Ignored if `insertion_value_names` is empty.
         """
 
         # Lock the table just once, to prevent it being done once per row.
@@ -1573,11 +1614,16 @@ class DatabasePool:
         # the lock is held for the remainder of the current transaction.
         self.engine.lock_table(txn, table)
 
-        for keyv, valv in zip(key_values, value_values):
+        for keyv, valv, insertionv in zip(
+            key_values, value_values, insertion_value_values
+        ):
             _keys = dict(zip(key_names, keyv))
             _vals = dict(zip(value_names, valv))
+            _insertion_vals = dict(zip(insertion_value_names, insertionv))
 
-            self.simple_upsert_txn_emulated(txn, table, _keys, _vals, lock=False)
+            self.simple_upsert_txn_emulated(
+                txn, table, _keys, _vals, insertion_values=_insertion_vals, lock=False
+            )
 
     @staticmethod
     def simple_upsert_many_txn_native_upsert(
@@ -1587,6 +1633,11 @@ class DatabasePool:
         key_values: Collection[Iterable[Any]],
         value_names: Collection[str],
         value_values: Iterable[Iterable[Any]],
+        # Given these are the same type as the normal values, force keyword-only so that
+        # they can't be confused.
+        *,
+        insertion_value_names: Collection[str] = (),
+        insertion_value_values: Collection[Iterable[Any]] = (),
     ) -> None:
         """
         Upsert, many times, using batching where possible.
@@ -1598,10 +1649,14 @@ class DatabasePool:
             value_names: The value column names
             value_values: A list of each row's value column values.
                 Ignored if value_names is empty.
+            insertion_value_names: The value column names to use only when inserting
+            insertion_value_values: A list of each row's value column values to use only
+                when inserting. Ignored if `insertion_value_names` is empty.
         """
         allnames: List[str] = []
         allnames.extend(key_names)
         allnames.extend(value_names)
+        allnames.extend(insertion_value_names)
 
         if not value_names:
             latter = "NOTHING"
@@ -1612,8 +1667,8 @@ class DatabasePool:
 
         args = []
 
-        for x, y in zip(key_values, value_values):
-            args.append(tuple(x) + tuple(y))
+        for x, y, z in zip(key_values, value_values, insertion_value_values):
+            args.append(tuple(x) + tuple(y) + tuple(z))
 
         if isinstance(txn.database_engine, PostgresEngine):
             # We use `execute_values` as it can be a lot faster than `execute_batch`,
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 5ed99516cf..06e0b53ca8 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -2500,7 +2500,8 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
                 table="sliding_sync_membership_snapshots",
                 key_names=key_names,
                 key_values=key_values,
-                # TODO: Implement these
+                value_names=(),
+                value_values=[],
                 insertion_value_names=insertion_value_names,
                 insertion_value_values=insertion_value_values,
             )
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 49dc973a36..9f87747652 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -31,6 +31,11 @@ from tests import unittest
 
 
 class UpdateUpsertManyTests(unittest.HomeserverTestCase):
+    """
+    Integration tests for the "simple" SQL generating methods in SQLBaseStore that
+    actually run against a database.
+    """
+
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.storage = hs.get_datastores().main
 
@@ -130,6 +135,65 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase):
             {(1, "user1", "hello"), (2, "user2", "bleb")},
         )
 
+    def test_upsert_many_with_insertion_values(self) -> None:
+        """
+        Upsert_many will only insert the
+        `insertion_value_names`/`insertion_value_values` (not on update/conflict)
+        """
+        # Add some data to an empty table
+        key_names = ["id", "username"]
+        key_values = [[1, "user1"], [2, "user2"]]
+        value_names: List[str] = []
+        value_values: List[List[str]] = []
+        insertion_value_names = ["value"]
+        insertion_value_values = [["hello"], ["there"]]
+
+        self.get_success(
+            self.storage.db_pool.runInteraction(
+                "test",
+                self.storage.db_pool.simple_upsert_many_txn,
+                self.table_name,
+                key_names,
+                key_values,
+                value_names,
+                value_values,
+                insertion_value_names=insertion_value_names,
+                insertion_value_values=insertion_value_values,
+            )
+        )
+
+        # Check results are what we expect
+        self.assertEqual(
+            set(self._dump_table_to_tuple()),
+            {(1, "user1", "hello"), (2, "user2", "there")},
+        )
+
+        # Update only user2
+        key_values = [[2, "user2"]]
+        # Since this row already exists, when we try to insert it again, it should not
+        # insert the value again.
+        insertion_value_values = [["again"]]
+
+        self.get_success(
+            self.storage.db_pool.runInteraction(
+                "test",
+                self.storage.db_pool.simple_upsert_many_txn,
+                self.table_name,
+                key_names,
+                key_values,
+                value_names,
+                value_values,
+                insertion_value_names=insertion_value_names,
+                insertion_value_values=insertion_value_values,
+            )
+        )
+
+        # Check results are what we expect
+        self.assertEqual(
+            set(self._dump_table_to_tuple()),
+            {(1, "user1", "hello"), (2, "user2", "there")},
+        )
+
     def test_simple_update_many(self) -> None:
         """
         simple_update_many performs many updates at once.
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 9420d03841..33ba2b2456 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -35,7 +35,12 @@ from tests.utils import USE_POSTGRES_FOR_TESTS, default_config
 
 
 class SQLBaseStoreTestCase(unittest.TestCase):
-    """Test the "simple" SQL generating methods in SQLBaseStore."""
+    """
+    Test the "simple" SQL generating methods in SQLBaseStore.
+
+    Tests that the SQL is generated correctly and that the correct arguments are passed
+    (does not actually run the queries or test the end-result in the database).
+    """
 
     def setUp(self) -> None:
         # This is the Twisted connection pool.
@@ -620,6 +625,82 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             [("oldvalue",)],
         )
 
+    @defer.inlineCallbacks
+    def test_upsert_many_no_values_with_insertion_values(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        """
+        With no value columns, the insertion-only columns are still part of the
+        generated INSERT and the conflict clause is `DO NOTHING`.
+        """
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_upsert_many(
+                table="tablename",
+                key_names=["keycol1"],
+                key_values=[["keyval1.1"], ["keyval1.2"]],
+                value_names=[],
+                value_values=[],
+                insertion_value_names=["insertioncol1"],
+                insertion_value_values=[["insertionvalue1"], ["insertionvalue2"]],
+                desc="",
+            )
+        )
+
+        if USE_POSTGRES_FOR_TESTS:
+            self.mock_execute_values.assert_called_once_with(
+                self.mock_txn,
+                "INSERT INTO tablename (keycol1, insertioncol1) VALUES ? ON CONFLICT (keycol1) DO NOTHING",
+                [("keyval1.1", "insertionvalue1"), ("keyval1.2", "insertionvalue2")],
+                template=None,
+                fetch=False,
+            )
+        else:
+            self.mock_txn.executemany.assert_called_once_with(
+                "INSERT INTO tablename (keycol1, insertioncol1) VALUES (?, ?) ON CONFLICT (keycol1) DO NOTHING",
+                [("keyval1.1", "insertionvalue1"), ("keyval1.2", "insertionvalue2")],
+            )
+
+    @defer.inlineCallbacks
+    def test_upsert_many_values_with_insertion_values(
+        self,
+    ) -> Generator["defer.Deferred[object]", object, None]:
+        """
+        With both value and insertion-only columns, every column is part of the
+        generated INSERT but only the value columns appear in `DO UPDATE SET`.
+        """
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_upsert_many(
+                table="tablename",
+                key_names=["keycol1"],
+                key_values=[["keyval1.1"], ["keyval1.2"]],
+                value_names=["valuecol1"],
+                value_values=[["value1"], ["value2"]],
+                insertion_value_names=["insertioncol1"],
+                insertion_value_values=[["insertionvalue1"], ["insertionvalue2"]],
+                desc="",
+            )
+        )
+
+        if USE_POSTGRES_FOR_TESTS:
+            self.mock_execute_values.assert_called_once_with(
+                self.mock_txn,
+                "INSERT INTO tablename (keycol1, valuecol1, insertioncol1) VALUES ? ON CONFLICT (keycol1) DO UPDATE SET valuecol1=EXCLUDED.valuecol1",
+                [
+                    ("keyval1.1", "value1", "insertionvalue1"),
+                    ("keyval1.2", "value2", "insertionvalue2"),
+                ],
+                template=None,
+                fetch=False,
+            )
+        else:
+            self.mock_txn.executemany.assert_called_once_with(
+                "INSERT INTO tablename (keycol1, valuecol1, insertioncol1) VALUES (?, ?, ?) ON CONFLICT (keycol1) DO UPDATE SET valuecol1=EXCLUDED.valuecol1",
+                [
+                    ("keyval1.1", "value1", "insertionvalue1"),
+                    ("keyval1.2", "value2", "insertionvalue2"),
+                ],
+            )
+
     @defer.inlineCallbacks
     def test_upsert_emulated_no_values_exists(
         self,