Fix bug where a new writer advances their token too quickly (#16473)

* Fix bug where a new writer advances their token too quickly

When starting a new writer (for e.g. persisting events), the
`MultiWriterIdGenerator` doesn't have a minimum token for it as there
are no rows matching that new writer in the DB.

This results in the the first stream ID it acquired being announced as
persisted *before* it actually finishes persisting, if another writer
gets and persists a subsequent stream ID. This is due to the logic of
setting the minimum persisted position to the minimum known position of
across all writers, and the new writer starts off not being considered.

* Fix sending out POSITIONs when our token advances without update

Broke in #14820

* For replication HTTP requests, only wait for minimal position
This commit is contained in:
Erik Johnston 2023-10-23 16:57:30 +01:00 committed by GitHub
parent 3bc23cc45c
commit 8f35f8148e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 305 additions and 77 deletions

1
changelog.d/16473.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a long-standing, exceedingly rare edge case where the first event persisted by a new event persister worker might not be sent down `/sync`.

View file

@ -51,17 +51,24 @@ will be inserted with that ID.
For any given stream reader (including writers themselves), we may define a per-writer current stream ID: For any given stream reader (including writers themselves), we may define a per-writer current stream ID:
> The current stream ID _for a writer W_ is the largest stream ID such that > A current stream ID _for a writer W_ is the largest stream ID such that
> all transactions added by W with equal or smaller ID have completed. > all transactions added by W with equal or smaller ID have completed.
Similarly, there is a "linear" notion of current stream ID: Similarly, there is a "linear" notion of current stream ID:
> The "linear" current stream ID is the largest stream ID such that > A "linear" current stream ID is the largest stream ID such that
> all facts (added by any writer) with equal or smaller ID have completed. > all facts (added by any writer) with equal or smaller ID have completed.
Because different stream readers A and B learn about new facts at different times, A and B may disagree about current stream IDs. Because different stream readers A and B learn about new facts at different times, A and B may disagree about current stream IDs.
Put differently: we should think of stream readers as being independent of each other, proceeding through a stream of facts at different rates. Put differently: we should think of stream readers as being independent of each other, proceeding through a stream of facts at different rates.
The above definition does not give a unique current stream ID, in fact there can
be a range of current stream IDs. Synapse uses both the minimum and maximum IDs
for different purposes. Most often the maximum is used, as its generally
beneficial for workers to advance their IDs as soon as possible. However, the
minimum is used in situations where e.g. another worker is going to wait until
the stream advances past a position.
**NB.** For both senses of "current", that if a writer opens a transaction that never completes, the current stream ID will never advance beyond that writer's last written stream ID. **NB.** For both senses of "current", that if a writer opens a transaction that never completes, the current stream ID will never advance beyond that writer's last written stream ID.
For single-writer streams, the per-writer current ID and the linear current ID are the same. For single-writer streams, the per-writer current ID and the linear current ID are the same.

View file

@ -238,7 +238,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
data[_STREAM_POSITION_KEY] = { data[_STREAM_POSITION_KEY] = {
"streams": { "streams": {
stream.NAME: stream.current_token(local_instance_name) stream.NAME: stream.minimal_local_current_token()
for stream in streams for stream in streams
}, },
"instance_name": local_instance_name, "instance_name": local_instance_name,

View file

@ -33,6 +33,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -107,22 +108,10 @@ class Stream:
def __init__( def __init__(
self, self,
local_instance_name: str, local_instance_name: str,
current_token_function: Callable[[str], Token],
update_function: UpdateFunction, update_function: UpdateFunction,
): ):
"""Instantiate a Stream """Instantiate a Stream
`current_token_function` and `update_function` are callbacks which
should be implemented by subclasses.
`current_token_function` takes an instance name, which is a writer to
the stream, and returns the position in the stream of the writer (as
viewed from the current process). On the writer process this is where
the writer has successfully written up to, whereas on other processes
this is the position which we have received updates up to over
replication. (Note that most streams have a single writer and so their
implementations ignore the instance name passed in).
`update_function` is called to get updates for this stream between a `update_function` is called to get updates for this stream between a
pair of stream tokens. See the `UpdateFunction` type definition for more pair of stream tokens. See the `UpdateFunction` type definition for more
info. info.
@ -133,12 +122,28 @@ class Stream:
update_function: callback go get stream updates, as above update_function: callback go get stream updates, as above
""" """
self.local_instance_name = local_instance_name self.local_instance_name = local_instance_name
self.current_token = current_token_function
self.update_function = update_function self.update_function = update_function
# The token from which we last asked for updates # The token from which we last asked for updates
self.last_token = self.current_token(self.local_instance_name) self.last_token = self.current_token(self.local_instance_name)
def current_token(self, instance_name: str) -> Token:
"""This takes an instance name, which is a writer to
the stream, and returns the position in the stream of the writer (as
viewed from the current process).
"""
# We can't make this an abstract class as it makes mypy unhappy.
raise NotImplementedError()
def minimal_local_current_token(self) -> Token:
"""Tries to return a minimal current token for the local instance,
i.e. for writers this would be the last successful write.
If local instance is not a writer (or has written yet) then falls back
to returning the normal "current token".
"""
raise NotImplementedError()
def discard_updates_and_advance(self) -> None: def discard_updates_and_advance(self) -> None:
"""Called when the stream should advance but the updates would be discarded, """Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers. e.g. when there are no currently connected workers.
@ -190,6 +195,25 @@ class Stream:
return updates, upto_token, limited return updates, upto_token, limited
class _StreamFromIdGen(Stream):
"""Helper class for simple streams that use a stream ID generator"""
def __init__(
self,
local_instance_name: str,
update_function: UpdateFunction,
stream_id_gen: "AbstractStreamIdGenerator",
):
self._stream_id_gen = stream_id_gen
super().__init__(local_instance_name, update_function)
def current_token(self, instance_name: str) -> Token:
return self._stream_id_gen.get_current_token_for_writer(instance_name)
def minimal_local_current_token(self) -> Token:
return self._stream_id_gen.get_minimal_local_current_token()
def current_token_without_instance( def current_token_without_instance(
current_token: Callable[[], int] current_token: Callable[[], int]
) -> Callable[[str], int]: ) -> Callable[[str], int]:
@ -242,17 +266,21 @@ class BackfillStream(Stream):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
self._current_token,
self.store.get_all_new_backfill_event_rows, self.store.get_all_new_backfill_event_rows,
) )
def _current_token(self, instance_name: str) -> int: def current_token(self, instance_name: str) -> Token:
# The backfill stream over replication operates on *positive* numbers, # The backfill stream over replication operates on *positive* numbers,
# which means we need to negate it. # which means we need to negate it.
return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name) return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name)
def minimal_local_current_token(self) -> Token:
# The backfill stream over replication operates on *positive* numbers,
# which means we need to negate it.
return -self.store._backfill_id_gen.get_minimal_local_current_token()
class PresenceStream(Stream):
class PresenceStream(_StreamFromIdGen):
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class PresenceStreamRow: class PresenceStreamRow:
user_id: str user_id: str
@ -283,9 +311,7 @@ class PresenceStream(Stream):
update_function = make_http_update_function(hs, self.NAME) update_function = make_http_update_function(hs, self.NAME)
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(), update_function, store._presence_id_gen
current_token_without_instance(store.get_current_presence_token),
update_function,
) )
@ -305,13 +331,18 @@ class PresenceFederationStream(Stream):
ROW_TYPE = PresenceFederationStreamRow ROW_TYPE = PresenceFederationStreamRow
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
federation_queue = hs.get_presence_handler().get_federation_queue() self._federation_queue = hs.get_presence_handler().get_federation_queue()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
federation_queue.get_current_token, self._federation_queue.get_replication_rows,
federation_queue.get_replication_rows,
) )
def current_token(self, instance_name: str) -> Token:
return self._federation_queue.get_current_token(instance_name)
def minimal_local_current_token(self) -> Token:
return self._federation_queue.get_current_token(self.local_instance_name)
class TypingStream(Stream): class TypingStream(Stream):
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -341,20 +372,25 @@ class TypingStream(Stream):
update_function: Callable[ update_function: Callable[
[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]] [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
] = typing_writer_handler.get_all_typing_updates ] = typing_writer_handler.get_all_typing_updates
current_token_function = typing_writer_handler.get_current_token self.current_token_function = typing_writer_handler.get_current_token
else: else:
# Query the typing writer process # Query the typing writer process
update_function = make_http_update_function(hs, self.NAME) update_function = make_http_update_function(hs, self.NAME)
current_token_function = hs.get_typing_handler().get_current_token self.current_token_function = hs.get_typing_handler().get_current_token
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(current_token_function),
update_function, update_function,
) )
def current_token(self, instance_name: str) -> Token:
return self.current_token_function()
class ReceiptsStream(Stream): def minimal_local_current_token(self) -> Token:
return self.current_token_function()
class ReceiptsStream(_StreamFromIdGen):
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class ReceiptsStreamRow: class ReceiptsStreamRow:
room_id: str room_id: str
@ -371,12 +407,12 @@ class ReceiptsStream(Stream):
store = hs.get_datastores().main store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(store.get_max_receipt_stream_id),
store.get_all_updated_receipts, store.get_all_updated_receipts,
store._receipts_id_gen,
) )
class PushRulesStream(Stream): class PushRulesStream(_StreamFromIdGen):
"""A user has changed their push rules""" """A user has changed their push rules"""
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -387,20 +423,16 @@ class PushRulesStream(Stream):
ROW_TYPE = PushRulesStreamRow ROW_TYPE = PushRulesStreamRow
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
self._current_token, store.get_all_push_rule_updates,
self.store.get_all_push_rule_updates, store._push_rules_stream_id_gen,
) )
def _current_token(self, instance_name: str) -> int:
push_rules_token = self.store.get_max_push_rules_stream_id()
return push_rules_token
class PushersStream(_StreamFromIdGen):
class PushersStream(Stream):
"""A user has added/changed/removed a pusher""" """A user has added/changed/removed a pusher"""
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -418,8 +450,8 @@ class PushersStream(Stream):
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(store.get_pushers_stream_token),
store.get_all_updated_pushers_rows, store.get_all_updated_pushers_rows,
store._pushers_id_gen,
) )
@ -447,15 +479,20 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow ROW_TYPE = CachesStreamRow
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
store = hs.get_datastores().main self.store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
store.get_cache_stream_token_for_writer, self.store.get_all_updated_caches,
store.get_all_updated_caches,
) )
def current_token(self, instance_name: str) -> Token:
return self.store.get_cache_stream_token_for_writer(instance_name)
class DeviceListsStream(Stream): def minimal_local_current_token(self) -> Token:
return self.current_token(self.local_instance_name)
class DeviceListsStream(_StreamFromIdGen):
"""Either a user has updated their devices or a remote server needs to be """Either a user has updated their devices or a remote server needs to be
told about a device update. told about a device update.
""" """
@ -473,8 +510,8 @@ class DeviceListsStream(Stream):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(self.store.get_device_stream_token),
self._update_function, self._update_function,
self.store._device_list_id_gen,
) )
async def _update_function( async def _update_function(
@ -525,7 +562,7 @@ class DeviceListsStream(Stream):
return updates, upper_limit_token, devices_limited or signatures_limited return updates, upper_limit_token, devices_limited or signatures_limited
class ToDeviceStream(Stream): class ToDeviceStream(_StreamFromIdGen):
"""New to_device messages for a client""" """New to_device messages for a client"""
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -539,12 +576,12 @@ class ToDeviceStream(Stream):
store = hs.get_datastores().main store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(store.get_to_device_stream_token),
store.get_all_new_device_messages, store.get_all_new_device_messages,
store._device_inbox_id_gen,
) )
class AccountDataStream(Stream): class AccountDataStream(_StreamFromIdGen):
"""Global or per room account data was changed""" """Global or per room account data was changed"""
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -560,8 +597,8 @@ class AccountDataStream(Stream):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(self.store.get_max_account_data_stream_id),
self._update_function, self._update_function,
self.store._account_data_id_gen,
) )
async def _update_function( async def _update_function(

View file

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
import attr import attr
from synapse.replication.tcp.streams._base import ( from synapse.replication.tcp.streams._base import (
Stream,
StreamRow, StreamRow,
StreamUpdateResult, StreamUpdateResult,
Token, Token,
_StreamFromIdGen,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -139,7 +139,7 @@ _EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
TypeToRow = {Row.TypeId: Row for Row in _EventRows} TypeToRow = {Row.TypeId: Row for Row in _EventRows}
class EventsStream(Stream): class EventsStream(_StreamFromIdGen):
"""We received a new event, or an event went from being an outlier to not""" """We received a new event, or an event went from being an outlier to not"""
NAME = "events" NAME = "events"
@ -147,9 +147,7 @@ class EventsStream(Stream):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main self._store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(), self._update_function, self._store._stream_id_gen
self._store._stream_id_gen.get_current_token_for_writer,
self._update_function,
) )
async def _update_function( async def _update_function(

View file

@ -18,6 +18,7 @@ import attr
from synapse.replication.tcp.streams._base import ( from synapse.replication.tcp.streams._base import (
Stream, Stream,
Token,
current_token_without_instance, current_token_without_instance,
make_http_update_function, make_http_update_function,
) )
@ -47,7 +48,7 @@ class FederationStream(Stream):
# will be a real FederationSender, which has stubs for current_token and # will be a real FederationSender, which has stubs for current_token and
# get_replication_rows.) # get_replication_rows.)
federation_sender = hs.get_federation_sender() federation_sender = hs.get_federation_sender()
current_token = current_token_without_instance( self.current_token_func = current_token_without_instance(
federation_sender.get_current_token federation_sender.get_current_token
) )
update_function: Callable[ update_function: Callable[
@ -57,15 +58,21 @@ class FederationStream(Stream):
elif hs.should_send_federation(): elif hs.should_send_federation():
# federation sender: Query master process # federation sender: Query master process
update_function = make_http_update_function(hs, self.NAME) update_function = make_http_update_function(hs, self.NAME)
current_token = self._stub_current_token self.current_token_func = self._stub_current_token
else: else:
# other worker: stub out the update function (we're not interested in # other worker: stub out the update function (we're not interested in
# any updates so when we get a POSITION we do nothing) # any updates so when we get a POSITION we do nothing)
update_function = self._stub_update_function update_function = self._stub_update_function
current_token = self._stub_current_token self.current_token_func = self._stub_current_token
super().__init__(hs.get_instance_name(), current_token, update_function) super().__init__(hs.get_instance_name(), update_function)
def current_token(self, instance_name: str) -> Token:
return self.current_token_func(instance_name)
def minimal_local_current_token(self) -> Token:
return self.current_token(self.local_instance_name)
@staticmethod @staticmethod
def _stub_current_token(instance_name: str) -> int: def _stub_current_token(instance_name: str) -> int:

View file

@ -15,7 +15,7 @@ from typing import TYPE_CHECKING
import attr import attr
from synapse.replication.tcp.streams import Stream from synapse.replication.tcp.streams._base import _StreamFromIdGen
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -27,7 +27,7 @@ class UnPartialStatedRoomStreamRow:
room_id: str room_id: str
class UnPartialStatedRoomStream(Stream): class UnPartialStatedRoomStream(_StreamFromIdGen):
""" """
Stream to notify about rooms becoming un-partial-stated; Stream to notify about rooms becoming un-partial-stated;
that is, when the background sync finishes such that we now have full state for that is, when the background sync finishes such that we now have full state for
@ -41,8 +41,8 @@ class UnPartialStatedRoomStream(Stream):
store = hs.get_datastores().main store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
store.get_un_partial_stated_rooms_token,
store.get_un_partial_stated_rooms_from_stream, store.get_un_partial_stated_rooms_from_stream,
store._un_partial_stated_rooms_stream_id_gen,
) )
@ -56,7 +56,7 @@ class UnPartialStatedEventStreamRow:
rejection_status_changed: bool rejection_status_changed: bool
class UnPartialStatedEventStream(Stream): class UnPartialStatedEventStream(_StreamFromIdGen):
""" """
Stream to notify about events becoming un-partial-stated. Stream to notify about events becoming un-partial-stated.
""" """
@ -68,6 +68,6 @@ class UnPartialStatedEventStream(Stream):
store = hs.get_datastores().main store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
store.get_un_partial_stated_events_token,
store.get_un_partial_stated_events_from_stream, store.get_un_partial_stated_events_from_stream,
store._un_partial_stated_events_stream_id_gen,
) )

View file

@ -133,6 +133,15 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
""" """
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod
def get_minimal_local_current_token(self) -> int:
"""Tries to return a minimal current token for the local instance,
i.e. for writers this would be the last successful write.
If local instance is not a writer (or has written yet) then falls back
to returning the normal "current token".
"""
@abc.abstractmethod @abc.abstractmethod
def get_next(self) -> AsyncContextManager[int]: def get_next(self) -> AsyncContextManager[int]:
""" """
@ -312,6 +321,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
def get_current_token_for_writer(self, instance_name: str) -> int: def get_current_token_for_writer(self, instance_name: str) -> int:
return self.get_current_token() return self.get_current_token()
def get_minimal_local_current_token(self) -> int:
return self.get_current_token()
class MultiWriterIdGenerator(AbstractStreamIdGenerator): class MultiWriterIdGenerator(AbstractStreamIdGenerator):
"""Generates and tracks stream IDs for a stream with multiple writers. """Generates and tracks stream IDs for a stream with multiple writers.
@ -408,6 +420,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# The maximum stream ID that we have seen been allocated across any writer. # The maximum stream ID that we have seen been allocated across any writer.
self._max_seen_allocated_stream_id = 1 self._max_seen_allocated_stream_id = 1
# The maximum position of the local instance. This can be higher than
# the corresponding position in `current_positions` table when there are
# no active writes in progress.
self._max_position_of_local_instance = self._max_seen_allocated_stream_id
self._sequence_gen = PostgresSequenceGenerator(sequence_name) self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged. # We check that the table and sequence haven't diverged.
@ -427,6 +444,16 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._current_positions.values(), default=1 self._current_positions.values(), default=1
) )
# For the case where `stream_positions` is not up to date,
# `_persisted_upto_position` may be higher.
self._max_seen_allocated_stream_id = max(
self._max_seen_allocated_stream_id, self._persisted_upto_position
)
# Bump our local maximum position now that we've loaded things from the
# DB.
self._max_position_of_local_instance = self._max_seen_allocated_stream_id
if not writers: if not writers:
# If there have been no explicit writers given then any instance can # If there have been no explicit writers given then any instance can
# write to the stream. In which case, let's pre-seed our own # write to the stream. In which case, let's pre-seed our own
@ -545,6 +572,14 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
if instance == self._instance_name: if instance == self._instance_name:
self._current_positions[instance] = stream_id self._current_positions[instance] = stream_id
if self._writers:
# If we have explicit writers then make sure that each instance has
# a position.
for writer in self._writers:
self._current_positions.setdefault(
writer, self._persisted_upto_position
)
cur.close() cur.close()
def _load_next_id_txn(self, txn: Cursor) -> int: def _load_next_id_txn(self, txn: Cursor) -> int:
@ -688,6 +723,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
if new_cur: if new_cur:
curr = self._current_positions.get(self._instance_name, 0) curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, new_cur) self._current_positions[self._instance_name] = max(curr, new_cur)
self._max_position_of_local_instance = max(
curr, new_cur, self._max_position_of_local_instance
)
self._add_persisted_position(next_id) self._add_persisted_position(next_id)
@ -702,10 +740,26 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# persisted up to position. This stops Synapse from doing a full table # persisted up to position. This stops Synapse from doing a full table
# scan when a new writer announces itself over replication. # scan when a new writer announces itself over replication.
with self._lock: with self._lock:
return self._return_factor * self._current_positions.get( if self._instance_name == instance_name:
return self._return_factor * self._max_position_of_local_instance
pos = self._current_positions.get(
instance_name, self._persisted_upto_position instance_name, self._persisted_upto_position
) )
# We want to return the maximum "current token" that we can for a
# writer, this helps ensure that streams progress as fast as
# possible.
pos = max(pos, self._persisted_upto_position)
return self._return_factor * pos
def get_minimal_local_current_token(self) -> int:
with self._lock:
return self._return_factor * self._current_positions.get(
self._instance_name, self._persisted_upto_position
)
def get_positions(self) -> Dict[str, int]: def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current positon map. """Get a copy of the current positon map.
@ -774,6 +828,18 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._persisted_upto_position = max(min_curr, self._persisted_upto_position) self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
# Advance our local max position.
self._max_position_of_local_instance = max(
self._max_position_of_local_instance, self._persisted_upto_position
)
if not self._unfinished_ids and not self._in_flight_fetches:
# If we don't have anything in flight, it's safe to advance to the
# max seen stream ID.
self._max_position_of_local_instance = max(
self._max_seen_allocated_stream_id, self._max_position_of_local_instance
)
# We now iterate through the seen positions, discarding those that are # We now iterate through the seen positions, discarding those that are
# less than the current min positions, and incrementing the min position # less than the current min positions, and incrementing the min position
# if its exactly one greater. # if its exactly one greater.

View file

@ -259,8 +259,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator() id_gen = self._create_id_generator()
# The table is empty so we expect an empty map for positions # The table is empty so we expect the map for positions to have a dummy
self.assertEqual(id_gen.get_positions(), {}) # minimum value.
self.assertEqual(id_gen.get_positions(), {"master": 1})
def test_single_instance(self) -> None: def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled """Test that reads and writes from a single process are handled
@ -349,15 +350,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
first_id_gen = self._create_id_generator("first", writers=["first", "second"]) first_id_gen = self._create_id_generator("first", writers=["first", "second"])
second_id_gen = self._create_id_generator("second", writers=["first", "second"]) second_id_gen = self._create_id_generator("second", writers=["first", "second"])
# The first ID gen will notice that it can advance its token to 7 as it
# has no in progress writes...
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
# ... but the second ID gen doesn't know that.
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
# Try allocating a new ID gen and check that we only see position # Try allocating a new ID gen and check that we only see position
@ -398,6 +396,56 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen.advance("first", 8) second_id_gen.advance("first", 8)
self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
def test_multi_instance_empty_row(self) -> None:
"""Test that reads and writes from multiple processes are handled
correctly, when one of the writers starts without any rows.
"""
# Insert some rows for two out of three of the ID gens.
self._insert_rows("first", 3)
self._insert_rows("second", 4)
first_id_gen = self._create_id_generator(
"first", writers=["first", "second", "third"]
)
second_id_gen = self._create_id_generator(
"second", writers=["first", "second", "third"]
)
third_id_gen = self._create_id_generator(
"third", writers=["first", "second", "third"]
)
self.assertEqual(
first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
)
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("third"), 7)
self.assertEqual(
second_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
)
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("third"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
async def _get_next_async() -> None:
async with third_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
self.assertEqual(
third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
)
self.assertEqual(third_id_gen.get_persisted_upto_position(), 7)
self.get_success(_get_next_async())
self.assertEqual(
third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8}
)
def test_get_next_txn(self) -> None: def test_get_next_txn(self) -> None:
"""Test that the `get_next_txn` function works correctly.""" """Test that the `get_next_txn` function works correctly."""
@ -600,6 +648,70 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
with self.assertRaises(IncorrectDatabaseSetup): with self.assertRaises(IncorrectDatabaseSetup):
self._create_id_generator("first") self._create_id_generator("first")
def test_minimal_local_token(self) -> None:
self._insert_rows("first", 3)
self._insert_rows("second", 4)
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(second_id_gen.get_minimal_local_current_token(), 7)
def test_current_token_gap(self) -> None:
"""Test that getting the current token for a writer returns the maximal
token when there are no writes.
"""
self._insert_rows("first", 3)
self._insert_rows("second", 4)
first_id_gen = self._create_id_generator(
"first", writers=["first", "second", "third"]
)
second_id_gen = self._create_id_generator(
"second", writers=["first", "second", "third"]
)
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(second_id_gen.get_current_token(), 7)
# Check that the first ID gen advancing causes the second ID gen to
# advance (as the second ID gen has nothing in flight).
async def _get_next_async() -> None:
async with first_id_gen.get_next_mult(2):
pass
self.get_success(_get_next_async())
second_id_gen.advance("first", 9)
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 9)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
self.assertEqual(second_id_gen.get_current_token(), 7)
# Check that the first ID gen advancing doesn't advance the second ID
# gen when the second ID gen has stuff in flight.
self.get_success(_get_next_async())
ctxmgr = second_id_gen.get_next()
self.get_success(ctxmgr.__aenter__())
second_id_gen.advance("first", 11)
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
self.assertEqual(second_id_gen.get_current_token(), 7)
self.get_success(ctxmgr.__aexit__(None, None, None))
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 12)
self.assertEqual(second_id_gen.get_current_token(), 7)
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs.""" """Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
@ -712,8 +824,8 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(_get_next_async()) self.get_success(_get_next_async())
self.assertEqual(id_gen_1.get_positions(), {"first": -1}) self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -1})
self.assertEqual(id_gen_2.get_positions(), {"first": -1}) self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -1})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1) self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1) self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
@ -822,11 +934,11 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen = self._create_id_generator("second", writers=["first", "second"]) second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6}) self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(second_id_gen.get_persisted_upto_position(), 7) self.assertEqual(second_id_gen.get_persisted_upto_position(), 7)