Set our own stream position from the current sequence value on startup (#17309)

This commit is contained in:
Quentin Gliech 2024-06-17 13:50:00 +02:00 committed by GitHub
parent 12d7303707
commit f983a77ab0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 147 additions and 178 deletions

1
changelog.d/17309.misc Normal file
View file

@ -0,0 +1 @@
When rolling back to a previous Synapse version and then forwards again to this release, don't require server operators to manually run SQL.

View file

@ -276,9 +276,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# no active writes in progress. # no active writes in progress.
self._max_position_of_local_instance = self._max_seen_allocated_stream_id self._max_position_of_local_instance = self._max_seen_allocated_stream_id
# This goes and fills out the above state from the database.
self._load_current_ids(db_conn, tables)
self._sequence_gen = build_sequence_generator( self._sequence_gen = build_sequence_generator(
db_conn=db_conn, db_conn=db_conn,
database_engine=db.engine, database_engine=db.engine,
@ -303,6 +300,13 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
positive=positive, positive=positive,
) )
# This goes and fills out the above state from the database.
# This may read on the PostgreSQL sequence, and
# SequenceGenerator.check_consistency might have fixed up the sequence, which
# means the SequenceGenerator needs to be setup before we read the value from
# the sequence.
self._load_current_ids(db_conn, tables, sequence_name)
self._max_seen_allocated_stream_id = max( self._max_seen_allocated_stream_id = max(
self._current_positions.values(), default=1 self._current_positions.values(), default=1
) )
@ -327,6 +331,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self, self,
db_conn: LoggingDatabaseConnection, db_conn: LoggingDatabaseConnection,
tables: List[Tuple[str, str, str]], tables: List[Tuple[str, str, str]],
sequence_name: str,
) -> None: ) -> None:
cur = db_conn.cursor(txn_name="_load_current_ids") cur = db_conn.cursor(txn_name="_load_current_ids")
@ -360,6 +365,18 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
if instance in self._writers if instance in self._writers
} }
# If we're a writer, we can assume we're at the end of the stream
# Usually, we would get that from the stream_positions, but in some cases,
# like if we rolled back Synapse, the stream_positions table might not be up to
# date. If we're using Postgres for the sequences, we can just use the current
# sequence value as our own position.
if self._instance_name in self._writers:
if isinstance(self._db.engine, PostgresEngine):
cur.execute(f"SELECT last_value FROM {sequence_name}")
row = cur.fetchone()
assert row is not None
self._current_positions[self._instance_name] = row[0]
# We set the `_persisted_upto_position` to be the minimum of all current # We set the `_persisted_upto_position` to be the minimum of all current
# positions. If empty we use the max stream ID from the DB table. # positions. If empty we use the max stream ID from the DB table.
min_stream_id = min(self._current_positions.values(), default=None) min_stream_id = min(self._current_positions.values(), default=None)

View file

@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited] # [This file includes modifications made by New Vector Limited]
# #
# #
from typing import List, Optional from typing import Dict, List, Optional
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -42,9 +42,13 @@ from tests.utils import USE_POSTGRES_FOR_TESTS
class MultiWriterIdGeneratorBase(HomeserverTestCase): class MultiWriterIdGeneratorBase(HomeserverTestCase):
positive: bool = True
tables: List[str] = ["foobar"]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool self.db_pool: DatabasePool = self.store.db_pool
self.instances: Dict[str, MultiWriterIdGenerator] = {}
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
@ -57,18 +61,22 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase):
if USE_POSTGRES_FOR_TESTS: if USE_POSTGRES_FOR_TESTS:
txn.execute("CREATE SEQUENCE foobar_seq") txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute( for table in self.tables:
""" txn.execute(
CREATE TABLE foobar ( """
stream_id BIGINT NOT NULL, CREATE TABLE %s (
instance_name TEXT NOT NULL, stream_id BIGINT NOT NULL,
data TEXT instance_name TEXT NOT NULL,
); data TEXT
""" );
) """
% (table,)
)
def _create_id_generator( def _create_id_generator(
self, instance_name: str = "master", writers: Optional[List[str]] = None self,
instance_name: str = "master",
writers: Optional[List[str]] = None,
) -> MultiWriterIdGenerator: ) -> MultiWriterIdGenerator:
def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator( return MultiWriterIdGenerator(
@ -77,36 +85,93 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase):
notifier=self.hs.get_replication_notifier(), notifier=self.hs.get_replication_notifier(),
stream_name="test_stream", stream_name="test_stream",
instance_name=instance_name, instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")], tables=[(table, "instance_name", "stream_id") for table in self.tables],
sequence_name="foobar_seq", sequence_name="foobar_seq",
writers=writers or ["master"], writers=writers or ["master"],
positive=self.positive,
) )
return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) self.instances[instance_name] = self.get_success_or_raise(
self.db_pool.runWithConnection(_create)
)
return self.instances[instance_name]
def _insert_rows(self, instance_name: str, number: int) -> None: def _replicate(self, instance_name: str) -> None:
"""Similate a replication event for the given instance."""
writer = self.instances[instance_name]
token = writer.get_current_token_for_writer(instance_name)
for generator in self.instances.values():
if writer != generator:
generator.advance(instance_name, token)
def _replicate_all(self) -> None:
"""Similate a replication event for all instances."""
for instance_name in self.instances:
self._replicate(instance_name)
def _insert_row(
self, instance_name: str, stream_id: int, table: Optional[str] = None
) -> None:
"""Insert one row as the given instance with given stream_id."""
if table is None:
table = self.tables[0]
factor = 1 if self.positive else -1
def _insert(txn: LoggingTransaction) -> None:
txn.execute(
"INSERT INTO %s VALUES (?, ?)" % (table,),
(
stream_id,
instance_name,
),
)
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
""",
(instance_name, stream_id * factor, stream_id * factor),
)
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
def _insert_rows(
self,
instance_name: str,
number: int,
table: Optional[str] = None,
update_stream_table: bool = True,
) -> None:
"""Insert N rows as the given instance, inserting with stream IDs pulled """Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence. from the postgres sequence.
""" """
if table is None:
table = self.tables[0]
factor = 1 if self.positive else -1
def _insert(txn: LoggingTransaction) -> None: def _insert(txn: LoggingTransaction) -> None:
for _ in range(number): for _ in range(number):
next_val = self.seq_gen.get_next_id_txn(txn) next_val = self.seq_gen.get_next_id_txn(txn)
txn.execute( txn.execute(
"INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)", "INSERT INTO %s (stream_id, instance_name) VALUES (?, ?)"
( % (table,),
next_val, (next_val, instance_name),
instance_name,
),
) )
txn.execute( if update_stream_table:
""" txn.execute(
INSERT INTO stream_positions VALUES ('test_stream', ?, ?) """
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
""", ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
(instance_name, next_val, next_val), """,
) (instance_name, next_val * factor, next_val * factor),
)
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
@ -353,7 +418,9 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
id_gen = self._create_id_generator("first", writers=["first", "second"]) id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) # When the writer is created, it assumes its own position is the current head of
# the sequence
self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
self.assertEqual(id_gen.get_persisted_upto_position(), 5) self.assertEqual(id_gen.get_persisted_upto_position(), 5)
@ -375,11 +442,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
correctly. correctly.
""" """
self._insert_rows("first", 3) self._insert_rows("first", 3)
self._insert_rows("second", 4)
first_id_gen = self._create_id_generator("first", writers=["first", "second"]) first_id_gen = self._create_id_generator("first", writers=["first", "second"])
self._insert_rows("second", 4)
second_id_gen = self._create_id_generator("second", writers=["first", "second"]) second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self._replicate_all()
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
@ -398,6 +467,9 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
self.assertEqual( self.assertEqual(
first_id_gen.get_positions(), {"first": 3, "second": 7} first_id_gen.get_positions(), {"first": 3, "second": 7}
) )
self.assertEqual(
second_id_gen.get_positions(), {"first": 3, "second": 7}
)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
self.get_success(_get_next_async()) self.get_success(_get_next_async())
@ -432,11 +504,11 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
""" """
# Insert some rows for two out of three of the ID gens. # Insert some rows for two out of three of the ID gens.
self._insert_rows("first", 3) self._insert_rows("first", 3)
self._insert_rows("second", 4)
first_id_gen = self._create_id_generator( first_id_gen = self._create_id_generator(
"first", writers=["first", "second", "third"] "first", writers=["first", "second", "third"]
) )
self._insert_rows("second", 4)
second_id_gen = self._create_id_generator( second_id_gen = self._create_id_generator(
"second", writers=["first", "second", "third"] "second", writers=["first", "second", "third"]
) )
@ -444,6 +516,8 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
"third", writers=["first", "second", "third"] "third", writers=["first", "second", "third"]
) )
self._replicate_all()
self.assertEqual( self.assertEqual(
first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7} first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
) )
@ -546,11 +620,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
def test_minimal_local_token(self) -> None: def test_minimal_local_token(self) -> None:
self._insert_rows("first", 3) self._insert_rows("first", 3)
self._insert_rows("second", 4)
first_id_gen = self._create_id_generator("first", writers=["first", "second"]) first_id_gen = self._create_id_generator("first", writers=["first", "second"])
self._insert_rows("second", 4)
second_id_gen = self._create_id_generator("second", writers=["first", "second"]) second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self._replicate_all()
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3) self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)
@ -562,15 +638,17 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
token when there are no writes. token when there are no writes.
""" """
self._insert_rows("first", 3) self._insert_rows("first", 3)
self._insert_rows("second", 4)
first_id_gen = self._create_id_generator( first_id_gen = self._create_id_generator(
"first", writers=["first", "second", "third"] "first", writers=["first", "second", "third"]
) )
self._insert_rows("second", 4)
second_id_gen = self._create_id_generator( second_id_gen = self._create_id_generator(
"second", writers=["first", "second", "third"] "second", writers=["first", "second", "third"]
) )
self._replicate_all()
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7) self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(second_id_gen.get_current_token(), 7) self.assertEqual(second_id_gen.get_current_token(), 7)
@ -609,68 +687,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
self.assertEqual(second_id_gen.get_current_token(), 7) self.assertEqual(second_id_gen.get_current_token(), 7)
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): class BackwardsMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs.""" """Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
if not USE_POSTGRES_FOR_TESTS: if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres" skip = "Requires Postgres"
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: positive = False
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
CREATE TABLE foobar (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)
def _create_id_generator(
self, instance_name: str = "master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator(
conn,
self.db_pool,
notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
writers=writers or ["master"],
positive=False,
)
return self.get_success(self.db_pool.runWithConnection(_create))
def _insert_row(self, instance_name: str, stream_id: int) -> None:
"""Insert one row as the given instance with given stream_id."""
def _insert(txn: LoggingTransaction) -> None:
txn.execute(
"INSERT INTO foobar VALUES (?, ?)",
(
stream_id,
instance_name,
),
)
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
""",
(instance_name, -stream_id, -stream_id),
)
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
def test_single_instance(self) -> None: def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled """Test that reads and writes from a single process are handled
@ -716,7 +739,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
async def _get_next_async() -> None: async def _get_next_async() -> None:
async with id_gen_1.get_next() as stream_id: async with id_gen_1.get_next() as stream_id:
self._insert_row("first", stream_id) self._insert_row("first", stream_id)
id_gen_2.advance("first", stream_id) self._replicate("first")
self.get_success(_get_next_async()) self.get_success(_get_next_async())
@ -728,7 +751,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
async def _get_next_async2() -> None: async def _get_next_async2() -> None:
async with id_gen_2.get_next() as stream_id: async with id_gen_2.get_next() as stream_id:
self._insert_row("second", stream_id) self._insert_row("second", stream_id)
id_gen_1.advance("second", stream_id) self._replicate("second")
self.get_success(_get_next_async2()) self.get_success(_get_next_async2())
@ -738,98 +761,26 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): class MultiTableMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
if not USE_POSTGRES_FOR_TESTS: if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres" skip = "Requires Postgres"
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: tables = ["foobar1", "foobar2"]
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn: LoggingTransaction) -> None:
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
CREATE TABLE foobar1 (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)
txn.execute(
"""
CREATE TABLE foobar2 (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)
def _create_id_generator(
self, instance_name: str = "master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator(
conn,
self.db_pool,
notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
tables=[
("foobar1", "instance_name", "stream_id"),
("foobar2", "instance_name", "stream_id"),
],
sequence_name="foobar_seq",
writers=writers or ["master"],
)
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
def _insert_rows(
self,
table: str,
instance_name: str,
number: int,
update_stream_table: bool = True,
) -> None:
"""Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence.
"""
def _insert(txn: LoggingTransaction) -> None:
for _ in range(number):
txn.execute(
"INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
(instance_name,),
)
if update_stream_table:
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
""",
(instance_name,),
)
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
def test_load_existing_stream(self) -> None: def test_load_existing_stream(self) -> None:
"""Test creating ID gens with multiple tables that have rows from after """Test creating ID gens with multiple tables that have rows from after
the position in `stream_positions` table. the position in `stream_positions` table.
""" """
self._insert_rows("foobar1", "first", 3) self._insert_rows("first", 3, table="foobar1")
self._insert_rows("foobar2", "second", 3)
self._insert_rows("foobar2", "second", 1, update_stream_table=False)
first_id_gen = self._create_id_generator("first", writers=["first", "second"]) first_id_gen = self._create_id_generator("first", writers=["first", "second"])
self._insert_rows("second", 3, table="foobar2")
self._insert_rows("second", 1, table="foobar2", update_stream_table=False)
second_id_gen = self._create_id_generator("second", writers=["first", "second"]) second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6}) self._replicate_all()
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)