Make MultiWriterIDGenerator work for streams that use negative stream IDs (#8203)

This is so that we can use it for the backfill events stream.
This commit is contained in:
Erik Johnston 2020-09-01 13:36:25 +01:00 committed by GitHub
parent 318245eaa6
commit bbb3c8641c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 134 additions and 11 deletions

1
changelog.d/8203.misc Normal file
View file

@ -0,0 +1 @@
Make `MultiWriterIDGenerator` work for streams that use negative values.

View file

@ -185,6 +185,8 @@ class MultiWriterIdGenerator:
id_column: Column that stores the stream ID. id_column: Column that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new sequence_name: The name of the postgres sequence used to generate new
IDs. IDs.
positive: Whether the IDs are positive (true) or negative (false).
When using negative IDs we go backwards from -1 to -2, -3, etc.
""" """
def __init__( def __init__(
@ -196,13 +198,19 @@ class MultiWriterIdGenerator:
instance_column: str, instance_column: str,
id_column: str, id_column: str,
sequence_name: str, sequence_name: str,
positive: bool = True,
): ):
self._db = db self._db = db
self._instance_name = instance_name self._instance_name = instance_name
self._positive = positive
self._return_factor = 1 if positive else -1
# We lock as some functions may be called from DB threads. # We lock as some functions may be called from DB threads.
self._lock = threading.Lock() self._lock = threading.Lock()
# Note: If we are a negative stream then we still store all the IDs as
# positive to make life easier for us, and simply negate the IDs when we
# return them.
self._current_positions = self._load_current_ids( self._current_positions = self._load_current_ids(
db_conn, table, instance_column, id_column db_conn, table, instance_column, id_column
) )
@ -233,13 +241,16 @@ class MultiWriterIdGenerator:
def _load_current_ids( def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str self, db_conn, table: str, instance_column: str, id_column: str
) -> Dict[str, int]: ) -> Dict[str, int]:
# If positive stream aggregate via MAX. For negative stream use MIN
# *and* negate the result to get a positive number.
sql = """ sql = """
SELECT %(instance)s, MAX(%(id)s) FROM %(table)s SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
GROUP BY %(instance)s GROUP BY %(instance)s
""" % { """ % {
"instance": instance_column, "instance": instance_column,
"id": id_column, "id": id_column,
"table": table, "table": table,
"agg": "MAX" if self._positive else "-MIN",
} }
cur = db_conn.cursor() cur = db_conn.cursor()
@ -269,15 +280,16 @@ class MultiWriterIdGenerator:
# Assert the fetched ID is actually greater than what we currently # Assert the fetched ID is actually greater than what we currently
# believe the ID to be. If not, then the sequence and table have got # believe the ID to be. If not, then the sequence and table have got
# out of sync somehow. # out of sync somehow.
assert self.get_current_token_for_writer(self._instance_name) < next_id
with self._lock: with self._lock:
assert self._current_positions.get(self._instance_name, 0) < next_id
self._unfinished_ids.add(next_id) self._unfinished_ids.add(next_id)
@contextlib.contextmanager @contextlib.contextmanager
def manager(): def manager():
try: try:
yield next_id # Multiply by the return factor so that the ID has correct sign.
yield self._return_factor * next_id
finally: finally:
self._mark_id_as_finished(next_id) self._mark_id_as_finished(next_id)
@ -296,15 +308,15 @@ class MultiWriterIdGenerator:
# Assert the fetched ID is actually greater than any ID we've already # Assert the fetched ID is actually greater than any ID we've already
# seen. If not, then the sequence and table have got out of sync # seen. If not, then the sequence and table have got out of sync
# somehow. # somehow.
assert max(self.get_positions().values(), default=0) < min(next_ids)
with self._lock: with self._lock:
assert max(self._current_positions.values(), default=0) < min(next_ids)
self._unfinished_ids.update(next_ids) self._unfinished_ids.update(next_ids)
@contextlib.contextmanager @contextlib.contextmanager
def manager(): def manager():
try: try:
yield next_ids yield [self._return_factor * i for i in next_ids]
finally: finally:
for i in next_ids: for i in next_ids:
self._mark_id_as_finished(i) self._mark_id_as_finished(i)
@ -327,7 +339,7 @@ class MultiWriterIdGenerator:
txn.call_after(self._mark_id_as_finished, next_id) txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id)
return next_id return self._return_factor * next_id
def _mark_id_as_finished(self, next_id: int): def _mark_id_as_finished(self, next_id: int):
"""The ID has finished being processed so we should advance the """The ID has finished being processed so we should advance the
@ -359,20 +371,25 @@ class MultiWriterIdGenerator:
""" """
with self._lock: with self._lock:
return self._current_positions.get(instance_name, 0) return self._return_factor * self._current_positions.get(instance_name, 0)
def get_positions(self) -> Dict[str, int]: def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current positon map. """Get a copy of the current positon map.
""" """
with self._lock: with self._lock:
return dict(self._current_positions) return {
name: self._return_factor * i
for name, i in self._current_positions.items()
}
def advance(self, instance_name: str, new_id: int): def advance(self, instance_name: str, new_id: int):
"""Advance the postion of the named writer to the given ID, if greater """Advance the postion of the named writer to the given ID, if greater
than existing entry. than existing entry.
""" """
new_id *= self._return_factor
with self._lock: with self._lock:
self._current_positions[instance_name] = max( self._current_positions[instance_name] = max(
new_id, self._current_positions.get(instance_name, 0) new_id, self._current_positions.get(instance_name, 0)
@ -390,7 +407,7 @@ class MultiWriterIdGenerator:
""" """
with self._lock: with self._lock:
return self._persisted_upto_position return self._return_factor * self._persisted_upto_position
def _add_persisted_position(self, new_id: int): def _add_persisted_position(self, new_id: int):
"""Record that we have persisted a position. """Record that we have persisted a position.

View file

@ -264,3 +264,108 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# We assume that so long as `get_next` does correctly advance the # We assume that so long as `get_next` does correctly advance the
# `persisted_upto_position` in this case, then it will be correct in the # `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code). # other cases that are tested above (since they'll hit the same code).
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
    """Tests for a `MultiWriterIdGenerator` configured to hand out *negative*
    stream IDs (``positive=False``), i.e. one that counts -1, -2, -3, ...
    """

    # The generator relies on a real Postgres sequence; skip on SQLite.
    if not USE_POSTGRES_FOR_TESTS:
        skip = "Requires Postgres"

    def prepare(self, reactor, clock, hs):
        self.store = hs.get_datastore()
        self.db_pool = self.store.db_pool  # type: DatabasePool

        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))

    def _setup_db(self, txn):
        # A throwaway table + sequence for the generator under test to use.
        txn.execute("CREATE SEQUENCE foobar_seq")
        txn.execute(
            """
            CREATE TABLE foobar (
                stream_id BIGINT NOT NULL,
                instance_name TEXT NOT NULL,
                data TEXT
            );
            """
        )

    def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
        """Build a backwards (negative) ID generator for the given instance."""

        def _make(conn):
            return MultiWriterIdGenerator(
                conn,
                self.db_pool,
                instance_name=instance_name,
                table="foobar",
                instance_column="instance_name",
                id_column="stream_id",
                sequence_name="foobar_seq",
                positive=False,
            )

        return self.get_success(self.db_pool.runWithConnection(_make))

    def _insert_row(self, instance_name: str, stream_id: int):
        """Write a single row to the test table on behalf of `instance_name`."""

        def _do_insert(txn):
            txn.execute(
                "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
            )

        self.get_success(self.db_pool.runInteraction("_insert_row", _do_insert))

    def test_single_instance(self):
        """Test that reads and writes from a single process are handled
        correctly.
        """
        id_gen = self._create_id_generator()

        # First allocation should yield -1 and advance all views of the stream.
        with self.get_success(id_gen.get_next()) as stream_id:
            self._insert_row("master", stream_id)

        self.assertEqual(id_gen.get_positions(), {"master": -1})
        self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
        self.assertEqual(id_gen.get_persisted_upto_position(), -1)

        # A batch of three should land us at -4.
        with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
            for sid in stream_ids:
                self._insert_row("master", sid)

        self.assertEqual(id_gen.get_positions(), {"master": -4})
        self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
        self.assertEqual(id_gen.get_persisted_upto_position(), -4)

        # Test loading from DB by creating a second ID gen
        second_id_gen = self._create_id_generator()

        self.assertEqual(second_id_gen.get_positions(), {"master": -4})
        self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
        self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)

    def test_multiple_instance(self):
        """Tests that having multiple instances that get advanced over
        federation works correctly.
        """
        id_gen_1 = self._create_id_generator("first")
        id_gen_2 = self._create_id_generator("second")

        # "first" writes a row; "second" learns of it via advance().
        with self.get_success(id_gen_1.get_next()) as stream_id:
            self._insert_row("first", stream_id)
            id_gen_2.advance("first", stream_id)

        self.assertEqual(id_gen_1.get_positions(), {"first": -1})
        self.assertEqual(id_gen_2.get_positions(), {"first": -1})
        self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
        self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)

        # Now the roles are reversed: "second" writes, "first" advances.
        with self.get_success(id_gen_2.get_next()) as stream_id:
            self._insert_row("second", stream_id)
            id_gen_1.advance("second", stream_id)

        self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
        self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
        self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
        self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)