mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-22 01:25:44 +03:00
Support any process writing to cache invalidation stream. (#7436)
This commit is contained in:
parent
2929ce29d6
commit
d7983b63a6
26 changed files with 225 additions and 230 deletions
1
changelog.d/7436.misc
Normal file
1
changelog.d/7436.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Support any process writing to cache invalidation stream.
|
|
@ -219,10 +219,6 @@ Asks the server for the current position of all streams.
|
||||||
|
|
||||||
Inform the server a pusher should be removed
|
Inform the server a pusher should be removed
|
||||||
|
|
||||||
#### INVALIDATE_CACHE (C)
|
|
||||||
|
|
||||||
Inform the server a cache should be invalidated
|
|
||||||
|
|
||||||
### REMOTE_SERVER_UP (S, C)
|
### REMOTE_SERVER_UP (S, C)
|
||||||
|
|
||||||
Inform other processes that a remote server may have come back online.
|
Inform other processes that a remote server may have come back online.
|
||||||
|
|
|
@ -122,7 +122,7 @@ APPEND_ONLY_TABLES = [
|
||||||
"presence_stream",
|
"presence_stream",
|
||||||
"push_rules_stream",
|
"push_rules_stream",
|
||||||
"ex_outlier_stream",
|
"ex_outlier_stream",
|
||||||
"cache_invalidation_stream",
|
"cache_invalidation_stream_by_instance",
|
||||||
"public_room_list_stream",
|
"public_room_list_stream",
|
||||||
"state_group_edges",
|
"state_group_edges",
|
||||||
"stream_ordering_to_exterm",
|
"stream_ordering_to_exterm",
|
||||||
|
@ -188,7 +188,7 @@ class MockHomeserver:
|
||||||
self.clock = Clock(reactor)
|
self.clock = Clock(reactor)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.hostname = config.server_name
|
self.hostname = config.server_name
|
||||||
self.version_string = "Synapse/"+get_version_string(synapse)
|
self.version_string = "Synapse/" + get_version_string(synapse)
|
||||||
|
|
||||||
def get_clock(self):
|
def get_clock(self):
|
||||||
return self.clock
|
return self.clock
|
||||||
|
|
|
@ -18,14 +18,10 @@ from typing import Optional
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from synapse.storage.data_stores.main.cache import (
|
from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
|
||||||
CURRENT_STATE_CACHE_NAME,
|
|
||||||
CacheInvalidationWorkerStore,
|
|
||||||
)
|
|
||||||
from synapse.storage.database import Database
|
from synapse.storage.database import Database
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||||
from ._slaved_id_tracker import SlavedIdTracker
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -41,40 +37,16 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
|
||||||
def __init__(self, database: Database, db_conn, hs):
|
def __init__(self, database: Database, db_conn, hs):
|
||||||
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
|
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
self._cache_id_gen = SlavedIdTracker(
|
self._cache_id_gen = MultiWriterIdGenerator(
|
||||||
db_conn, "cache_invalidation_stream", "stream_id"
|
db_conn,
|
||||||
) # type: Optional[SlavedIdTracker]
|
database,
|
||||||
|
instance_name=hs.get_instance_name(),
|
||||||
|
table="cache_invalidation_stream_by_instance",
|
||||||
|
instance_column="instance_name",
|
||||||
|
id_column="stream_id",
|
||||||
|
sequence_name="cache_invalidation_stream_seq",
|
||||||
|
) # type: Optional[MultiWriterIdGenerator]
|
||||||
else:
|
else:
|
||||||
self._cache_id_gen = None
|
self._cache_id_gen = None
|
||||||
|
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
def get_cache_stream_token(self):
|
|
||||||
if self._cache_id_gen:
|
|
||||||
return self._cache_id_gen.get_current_token()
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
|
||||||
if stream_name == "caches":
|
|
||||||
if self._cache_id_gen:
|
|
||||||
self._cache_id_gen.advance(token)
|
|
||||||
for row in rows:
|
|
||||||
if row.cache_func == CURRENT_STATE_CACHE_NAME:
|
|
||||||
if row.keys is None:
|
|
||||||
raise Exception(
|
|
||||||
"Can't send an 'invalidate all' for current state cache"
|
|
||||||
)
|
|
||||||
|
|
||||||
room_id = row.keys[0]
|
|
||||||
members_changed = set(row.keys[1:])
|
|
||||||
self._invalidate_state_caches(room_id, members_changed)
|
|
||||||
else:
|
|
||||||
self._attempt_to_invalidate_cache(row.cache_func, row.keys)
|
|
||||||
|
|
||||||
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
|
|
||||||
txn.call_after(cache_func.invalidate, keys)
|
|
||||||
txn.call_after(self._send_invalidation_poke, cache_func, keys)
|
|
||||||
|
|
||||||
def _send_invalidation_poke(self, cache_func, keys):
|
|
||||||
self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
|
||||||
def get_max_account_data_stream_id(self):
|
def get_max_account_data_stream_id(self):
|
||||||
return self._account_data_id_gen.get_current_token()
|
return self._account_data_id_gen.get_current_token()
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||||
if stream_name == "tag_account_data":
|
if stream_name == "tag_account_data":
|
||||||
self._account_data_id_gen.advance(token)
|
self._account_data_id_gen.advance(token)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
|
@ -51,6 +51,4 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
|
||||||
(row.user_id, row.room_id, row.data_type)
|
(row.user_id, row.room_id, row.data_type)
|
||||||
)
|
)
|
||||||
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
|
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
|
||||||
return super(SlavedAccountDataStore, self).process_replication_rows(
|
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
stream_name, token, rows
|
|
||||||
)
|
|
||||||
|
|
|
@ -43,7 +43,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
|
||||||
expiry_ms=30 * 60 * 1000,
|
expiry_ms=30 * 60 * 1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||||
if stream_name == "to_device":
|
if stream_name == "to_device":
|
||||||
self._device_inbox_id_gen.advance(token)
|
self._device_inbox_id_gen.advance(token)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
|
@ -55,6 +55,4 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
|
||||||
self._device_federation_outbox_stream_cache.entity_has_changed(
|
self._device_federation_outbox_stream_cache.entity_has_changed(
|
||||||
row.entity, token
|
row.entity, token
|
||||||
)
|
)
|
||||||
return super(SlavedDeviceInboxStore, self).process_replication_rows(
|
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
stream_name, token, rows
|
|
||||||
)
|
|
||||||
|
|
|
@ -48,7 +48,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
||||||
"DeviceListFederationStreamChangeCache", device_list_max
|
"DeviceListFederationStreamChangeCache", device_list_max
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||||
if stream_name == DeviceListsStream.NAME:
|
if stream_name == DeviceListsStream.NAME:
|
||||||
self._device_list_id_gen.advance(token)
|
self._device_list_id_gen.advance(token)
|
||||||
self._invalidate_caches_for_devices(token, rows)
|
self._invalidate_caches_for_devices(token, rows)
|
||||||
|
@ -56,9 +56,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
||||||
self._device_list_id_gen.advance(token)
|
self._device_list_id_gen.advance(token)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
|
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
|
||||||
return super(SlavedDeviceStore, self).process_replication_rows(
|
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
stream_name, token, rows
|
|
||||||
)
|
|
||||||
|
|
||||||
def _invalidate_caches_for_devices(self, token, rows):
|
def _invalidate_caches_for_devices(self, token, rows):
|
||||||
for row in rows:
|
for row in rows:
|
||||||
|
|
|
@ -93,7 +93,7 @@ class SlavedEventStore(
|
||||||
def get_room_min_stream_ordering(self):
|
def get_room_min_stream_ordering(self):
|
||||||
return self._backfill_id_gen.get_current_token()
|
return self._backfill_id_gen.get_current_token()
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||||
if stream_name == "events":
|
if stream_name == "events":
|
||||||
self._stream_id_gen.advance(token)
|
self._stream_id_gen.advance(token)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
|
@ -111,9 +111,7 @@ class SlavedEventStore(
|
||||||
row.relates_to,
|
row.relates_to,
|
||||||
backfilled=True,
|
backfilled=True,
|
||||||
)
|
)
|
||||||
return super(SlavedEventStore, self).process_replication_rows(
|
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
stream_name, token, rows
|
|
||||||
)
|
|
||||||
|
|
||||||
def _process_event_stream_row(self, token, row):
|
def _process_event_stream_row(self, token, row):
|
||||||
data = row.data
|
data = row.data
|
||||||
|
|
|
@ -37,12 +37,10 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
|
||||||
def get_group_stream_token(self):
|
def get_group_stream_token(self):
|
||||||
return self._group_updates_id_gen.get_current_token()
|
return self._group_updates_id_gen.get_current_token()
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||||
if stream_name == "groups":
|
if stream_name == "groups":
|
||||||
self._group_updates_id_gen.advance(token)
|
self._group_updates_id_gen.advance(token)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
|
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
|
||||||
|
|
||||||
return super(SlavedGroupServerStore, self).process_replication_rows(
|
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
stream_name, token, rows
|
|
||||||
)
|
|
||||||
|
|
|
@ -41,12 +41,10 @@ class SlavedPresenceStore(BaseSlavedStore):
|
||||||
def get_current_presence_token(self):
|
def get_current_presence_token(self):
|
||||||
return self._presence_id_gen.get_current_token()
|
return self._presence_id_gen.get_current_token()
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||||
if stream_name == "presence":
|
if stream_name == "presence":
|
||||||
self._presence_id_gen.advance(token)
|
self._presence_id_gen.advance(token)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
self.presence_stream_cache.entity_has_changed(row.user_id, token)
|
self.presence_stream_cache.entity_has_changed(row.user_id, token)
|
||||||
self._get_presence_for_user.invalidate((row.user_id,))
|
self._get_presence_for_user.invalidate((row.user_id,))
|
||||||
return super(SlavedPresenceStore, self).process_replication_rows(
|
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
stream_name, token, rows
|
|
||||||
)
|
|
||||||
|
|
|
@ -37,13 +37,11 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
|
||||||
def get_max_push_rules_stream_id(self):
|
def get_max_push_rules_stream_id(self):
|
||||||
return self._push_rules_stream_id_gen.get_current_token()
|
return self._push_rules_stream_id_gen.get_current_token()
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||||
if stream_name == "push_rules":
|
if stream_name == "push_rules":
|
||||||
self._push_rules_stream_id_gen.advance(token)
|
self._push_rules_stream_id_gen.advance(token)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
self.get_push_rules_for_user.invalidate((row.user_id,))
|
self.get_push_rules_for_user.invalidate((row.user_id,))
|
||||||
self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
|
self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
|
||||||
self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
|
self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
|
||||||
return super(SlavedPushRuleStore, self).process_replication_rows(
|
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
stream_name, token, rows
|
|
||||||
)
|
|
||||||
|
|
|
@ -31,9 +31,7 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
|
||||||
def get_pushers_stream_token(self):
|
def get_pushers_stream_token(self):
|
||||||
return self._pushers_id_gen.get_current_token()
|
return self._pushers_id_gen.get_current_token()
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||||
if stream_name == "pushers":
|
if stream_name == "pushers":
|
||||||
self._pushers_id_gen.advance(token)
|
self._pushers_id_gen.advance(token)
|
||||||
return super(SlavedPusherStore, self).process_replication_rows(
|
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
stream_name, token, rows
|
|
||||||
)
|
|
||||||
|
|
|
@ -51,7 +51,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
|
||||||
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
|
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
|
||||||
self.get_receipts_for_room.invalidate((room_id, receipt_type))
|
self.get_receipts_for_room.invalidate((room_id, receipt_type))
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||||
if stream_name == "receipts":
|
if stream_name == "receipts":
|
||||||
self._receipts_id_gen.advance(token)
|
self._receipts_id_gen.advance(token)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
|
@ -60,6 +60,4 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
|
||||||
)
|
)
|
||||||
self._receipts_stream_cache.entity_has_changed(row.room_id, token)
|
self._receipts_stream_cache.entity_has_changed(row.room_id, token)
|
||||||
|
|
||||||
return super(SlavedReceiptsStore, self).process_replication_rows(
|
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
stream_name, token, rows
|
|
||||||
)
|
|
||||||
|
|
|
@ -30,8 +30,8 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
|
||||||
def get_current_public_room_stream_id(self):
|
def get_current_public_room_stream_id(self):
|
||||||
return self._public_room_id_gen.get_current_token()
|
return self._public_room_id_gen.get_current_token()
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||||
if stream_name == "public_rooms":
|
if stream_name == "public_rooms":
|
||||||
self._public_room_id_gen.advance(token)
|
self._public_room_id_gen.advance(token)
|
||||||
|
|
||||||
return super(RoomStore, self).process_replication_rows(stream_name, token, rows)
|
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
|
|
|
@ -100,10 +100,10 @@ class ReplicationDataHandler:
|
||||||
token: stream token for this batch of rows
|
token: stream token for this batch of rows
|
||||||
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
|
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
|
||||||
"""
|
"""
|
||||||
self.store.process_replication_rows(stream_name, token, rows)
|
self.store.process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
|
|
||||||
async def on_position(self, stream_name: str, token: int):
|
async def on_position(self, stream_name: str, instance_name: str, token: int):
|
||||||
self.store.process_replication_rows(stream_name, token, [])
|
self.store.process_replication_rows(stream_name, instance_name, token, [])
|
||||||
|
|
||||||
def on_remote_server_up(self, server: str):
|
def on_remote_server_up(self, server: str):
|
||||||
"""Called when get a new REMOTE_SERVER_UP command."""
|
"""Called when get a new REMOTE_SERVER_UP command."""
|
||||||
|
|
|
@ -341,37 +341,6 @@ class RemovePusherCommand(Command):
|
||||||
return " ".join((self.app_id, self.push_key, self.user_id))
|
return " ".join((self.app_id, self.push_key, self.user_id))
|
||||||
|
|
||||||
|
|
||||||
class InvalidateCacheCommand(Command):
|
|
||||||
"""Sent by the client to invalidate an upstream cache.
|
|
||||||
|
|
||||||
THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
|
|
||||||
NOT DISASTROUS IF WE DROP ON THE FLOOR.
|
|
||||||
|
|
||||||
Mainly used to invalidate destination retry timing caches.
|
|
||||||
|
|
||||||
Format::
|
|
||||||
|
|
||||||
INVALIDATE_CACHE <cache_func> <keys_json>
|
|
||||||
|
|
||||||
Where <keys_json> is a json list.
|
|
||||||
"""
|
|
||||||
|
|
||||||
NAME = "INVALIDATE_CACHE"
|
|
||||||
|
|
||||||
def __init__(self, cache_func, keys):
|
|
||||||
self.cache_func = cache_func
|
|
||||||
self.keys = keys
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_line(cls, line):
|
|
||||||
cache_func, keys_json = line.split(" ", 1)
|
|
||||||
|
|
||||||
return cls(cache_func, json.loads(keys_json))
|
|
||||||
|
|
||||||
def to_line(self):
|
|
||||||
return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
|
|
||||||
|
|
||||||
|
|
||||||
class UserIpCommand(Command):
|
class UserIpCommand(Command):
|
||||||
"""Sent periodically when a worker sees activity from a client.
|
"""Sent periodically when a worker sees activity from a client.
|
||||||
|
|
||||||
|
@ -439,7 +408,6 @@ _COMMANDS = (
|
||||||
UserSyncCommand,
|
UserSyncCommand,
|
||||||
FederationAckCommand,
|
FederationAckCommand,
|
||||||
RemovePusherCommand,
|
RemovePusherCommand,
|
||||||
InvalidateCacheCommand,
|
|
||||||
UserIpCommand,
|
UserIpCommand,
|
||||||
RemoteServerUpCommand,
|
RemoteServerUpCommand,
|
||||||
ClearUserSyncsCommand,
|
ClearUserSyncsCommand,
|
||||||
|
@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = (
|
||||||
ClearUserSyncsCommand.NAME,
|
ClearUserSyncsCommand.NAME,
|
||||||
FederationAckCommand.NAME,
|
FederationAckCommand.NAME,
|
||||||
RemovePusherCommand.NAME,
|
RemovePusherCommand.NAME,
|
||||||
InvalidateCacheCommand.NAME,
|
|
||||||
UserIpCommand.NAME,
|
UserIpCommand.NAME,
|
||||||
ErrorCommand.NAME,
|
ErrorCommand.NAME,
|
||||||
RemoteServerUpCommand.NAME,
|
RemoteServerUpCommand.NAME,
|
||||||
|
|
|
@ -15,18 +15,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
Iterable,
|
|
||||||
Iterator,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Set,
|
|
||||||
Tuple,
|
|
||||||
TypeVar,
|
|
||||||
)
|
|
||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
|
@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import (
|
||||||
ClearUserSyncsCommand,
|
ClearUserSyncsCommand,
|
||||||
Command,
|
Command,
|
||||||
FederationAckCommand,
|
FederationAckCommand,
|
||||||
InvalidateCacheCommand,
|
|
||||||
PositionCommand,
|
PositionCommand,
|
||||||
RdataCommand,
|
RdataCommand,
|
||||||
RemoteServerUpCommand,
|
RemoteServerUpCommand,
|
||||||
|
@ -171,7 +159,7 @@ class ReplicationCommandHandler:
|
||||||
return
|
return
|
||||||
|
|
||||||
for stream_name, stream in self._streams.items():
|
for stream_name, stream in self._streams.items():
|
||||||
current_token = stream.current_token()
|
current_token = stream.current_token(self._instance_name)
|
||||||
self.send_command(
|
self.send_command(
|
||||||
PositionCommand(stream_name, self._instance_name, current_token)
|
PositionCommand(stream_name, self._instance_name, current_token)
|
||||||
)
|
)
|
||||||
|
@ -210,18 +198,6 @@ class ReplicationCommandHandler:
|
||||||
|
|
||||||
self._notifier.on_new_replication_data()
|
self._notifier.on_new_replication_data()
|
||||||
|
|
||||||
async def on_INVALIDATE_CACHE(
|
|
||||||
self, conn: AbstractConnection, cmd: InvalidateCacheCommand
|
|
||||||
):
|
|
||||||
invalidate_cache_counter.inc()
|
|
||||||
|
|
||||||
if self._is_master:
|
|
||||||
# We invalidate the cache locally, but then also stream that to other
|
|
||||||
# workers.
|
|
||||||
await self._store.invalidate_cache_and_stream(
|
|
||||||
cmd.cache_func, tuple(cmd.keys)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
|
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
|
||||||
user_ip_cache_counter.inc()
|
user_ip_cache_counter.inc()
|
||||||
|
|
||||||
|
@ -295,7 +271,7 @@ class ReplicationCommandHandler:
|
||||||
rows: a list of Stream.ROW_TYPE objects as returned by
|
rows: a list of Stream.ROW_TYPE objects as returned by
|
||||||
Stream.parse_row.
|
Stream.parse_row.
|
||||||
"""
|
"""
|
||||||
logger.debug("Received rdata %s -> %s", stream_name, token)
|
logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
|
||||||
await self._replication_data_handler.on_rdata(
|
await self._replication_data_handler.on_rdata(
|
||||||
stream_name, instance_name, token, rows
|
stream_name, instance_name, token, rows
|
||||||
)
|
)
|
||||||
|
@ -326,7 +302,7 @@ class ReplicationCommandHandler:
|
||||||
self._pending_batches.pop(stream_name, [])
|
self._pending_batches.pop(stream_name, [])
|
||||||
|
|
||||||
# Find where we previously streamed up to.
|
# Find where we previously streamed up to.
|
||||||
current_token = stream.current_token()
|
current_token = stream.current_token(cmd.instance_name)
|
||||||
|
|
||||||
# If the position token matches our current token then we're up to
|
# If the position token matches our current token then we're up to
|
||||||
# date and there's nothing to do. Otherwise, fetch all updates
|
# date and there's nothing to do. Otherwise, fetch all updates
|
||||||
|
@ -363,7 +339,9 @@ class ReplicationCommandHandler:
|
||||||
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
|
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
|
||||||
|
|
||||||
# We've now caught up to position sent to us, notify handler.
|
# We've now caught up to position sent to us, notify handler.
|
||||||
await self._replication_data_handler.on_position(stream_name, cmd.token)
|
await self._replication_data_handler.on_position(
|
||||||
|
cmd.stream_name, cmd.instance_name, cmd.token
|
||||||
|
)
|
||||||
|
|
||||||
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
|
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
|
||||||
|
|
||||||
|
@ -491,12 +469,6 @@ class ReplicationCommandHandler:
|
||||||
cmd = RemovePusherCommand(app_id, push_key, user_id)
|
cmd = RemovePusherCommand(app_id, push_key, user_id)
|
||||||
self.send_command(cmd)
|
self.send_command(cmd)
|
||||||
|
|
||||||
def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
|
|
||||||
"""Poke the master to invalidate a cache.
|
|
||||||
"""
|
|
||||||
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
|
|
||||||
self.send_command(cmd)
|
|
||||||
|
|
||||||
def send_user_ip(
|
def send_user_ip(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
|
|
@ -25,7 +25,12 @@ from twisted.internet.protocol import Factory
|
||||||
|
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
|
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
|
||||||
from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
|
from synapse.replication.tcp.streams import (
|
||||||
|
STREAMS_MAP,
|
||||||
|
CachesStream,
|
||||||
|
FederationStream,
|
||||||
|
Stream,
|
||||||
|
)
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
stream_updates_counter = Counter(
|
stream_updates_counter = Counter(
|
||||||
|
@ -71,11 +76,16 @@ class ReplicationStreamer(object):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
|
self._instance_name = hs.get_instance_name()
|
||||||
|
|
||||||
self._replication_torture_level = hs.config.replication_torture_level
|
self._replication_torture_level = hs.config.replication_torture_level
|
||||||
|
|
||||||
# Work out list of streams that this instance is the source of.
|
# Work out list of streams that this instance is the source of.
|
||||||
self.streams = [] # type: List[Stream]
|
self.streams = [] # type: List[Stream]
|
||||||
|
|
||||||
|
# All workers can write to the cache invalidation stream.
|
||||||
|
self.streams.append(CachesStream(hs))
|
||||||
|
|
||||||
if hs.config.worker_app is None:
|
if hs.config.worker_app is None:
|
||||||
for stream in STREAMS_MAP.values():
|
for stream in STREAMS_MAP.values():
|
||||||
if stream == FederationStream and hs.config.send_federation:
|
if stream == FederationStream and hs.config.send_federation:
|
||||||
|
@ -83,6 +93,10 @@ class ReplicationStreamer(object):
|
||||||
# has been disabled on the master.
|
# has been disabled on the master.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if stream == CachesStream:
|
||||||
|
# We've already added it above.
|
||||||
|
continue
|
||||||
|
|
||||||
self.streams.append(stream(hs))
|
self.streams.append(stream(hs))
|
||||||
|
|
||||||
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
|
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
|
||||||
|
@ -145,7 +159,9 @@ class ReplicationStreamer(object):
|
||||||
random.shuffle(all_streams)
|
random.shuffle(all_streams)
|
||||||
|
|
||||||
for stream in all_streams:
|
for stream in all_streams:
|
||||||
if stream.last_token == stream.current_token():
|
if stream.last_token == stream.current_token(
|
||||||
|
self._instance_name
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if self._replication_torture_level:
|
if self._replication_torture_level:
|
||||||
|
@ -157,7 +173,7 @@ class ReplicationStreamer(object):
|
||||||
"Getting stream: %s: %s -> %s",
|
"Getting stream: %s: %s -> %s",
|
||||||
stream.NAME,
|
stream.NAME,
|
||||||
stream.last_token,
|
stream.last_token,
|
||||||
stream.current_token(),
|
stream.current_token(self._instance_name),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
updates, current_token, limited = await stream.get_updates()
|
updates, current_token, limited = await stream.get_updates()
|
||||||
|
|
|
@ -95,20 +95,25 @@ class Stream(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
local_instance_name: str,
|
local_instance_name: str,
|
||||||
current_token_function: Callable[[], Token],
|
current_token_function: Callable[[str], Token],
|
||||||
update_function: UpdateFunction,
|
update_function: UpdateFunction,
|
||||||
):
|
):
|
||||||
"""Instantiate a Stream
|
"""Instantiate a Stream
|
||||||
|
|
||||||
current_token_function and update_function are callbacks which should be
|
`current_token_function` and `update_function` are callbacks which
|
||||||
implemented by subclasses.
|
should be implemented by subclasses.
|
||||||
|
|
||||||
current_token_function is called to get the current token of the underlying
|
`current_token_function` takes an instance name, which is a writer to
|
||||||
stream. It is only meaningful on the process that is the source of the
|
the stream, and returns the position in the stream of the writer (as
|
||||||
replication stream (ie, usually the master).
|
viewed from the current process). On the writer process this is where
|
||||||
|
the writer has successfully written up to, whereas on other processes
|
||||||
|
this is the position which we have received updates up to over
|
||||||
|
replication. (Note that most streams have a single writer and so their
|
||||||
|
implementations ignore the instance name passed in).
|
||||||
|
|
||||||
update_function is called to get updates for this stream between a pair of
|
`update_function` is called to get updates for this stream between a
|
||||||
stream tokens. See the UpdateFunction type definition for more info.
|
pair of stream tokens. See the `UpdateFunction` type definition for more
|
||||||
|
info.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
local_instance_name: The instance name of the current process
|
local_instance_name: The instance name of the current process
|
||||||
|
@ -120,13 +125,13 @@ class Stream(object):
|
||||||
self.update_function = update_function
|
self.update_function = update_function
|
||||||
|
|
||||||
# The token from which we last asked for updates
|
# The token from which we last asked for updates
|
||||||
self.last_token = self.current_token()
|
self.last_token = self.current_token(self.local_instance_name)
|
||||||
|
|
||||||
def discard_updates_and_advance(self):
|
def discard_updates_and_advance(self):
|
||||||
"""Called when the stream should advance but the updates would be discarded,
|
"""Called when the stream should advance but the updates would be discarded,
|
||||||
e.g. when there are no currently connected workers.
|
e.g. when there are no currently connected workers.
|
||||||
"""
|
"""
|
||||||
self.last_token = self.current_token()
|
self.last_token = self.current_token(self.local_instance_name)
|
||||||
|
|
||||||
async def get_updates(self) -> StreamUpdateResult:
|
async def get_updates(self) -> StreamUpdateResult:
|
||||||
"""Gets all updates since the last time this function was called (or
|
"""Gets all updates since the last time this function was called (or
|
||||||
|
@ -138,7 +143,7 @@ class Stream(object):
|
||||||
position in stream, and `limited` is whether there are more updates
|
position in stream, and `limited` is whether there are more updates
|
||||||
to fetch.
|
to fetch.
|
||||||
"""
|
"""
|
||||||
current_token = self.current_token()
|
current_token = self.current_token(self.local_instance_name)
|
||||||
updates, current_token, limited = await self.get_updates_since(
|
updates, current_token, limited = await self.get_updates_since(
|
||||||
self.local_instance_name, self.last_token, current_token
|
self.local_instance_name, self.last_token, current_token
|
||||||
)
|
)
|
||||||
|
@ -170,6 +175,16 @@ class Stream(object):
|
||||||
return updates, upto_token, limited
|
return updates, upto_token, limited
|
||||||
|
|
||||||
|
|
||||||
|
def current_token_without_instance(
|
||||||
|
current_token: Callable[[], int]
|
||||||
|
) -> Callable[[str], int]:
|
||||||
|
"""Takes a current token callback function for a single writer stream
|
||||||
|
that doesn't take an instance name parameter and wraps it in a function that
|
||||||
|
does accept an instance name parameter but ignores it.
|
||||||
|
"""
|
||||||
|
return lambda instance_name: current_token()
|
||||||
|
|
||||||
|
|
||||||
def db_query_to_update_function(
|
def db_query_to_update_function(
|
||||||
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
|
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
|
||||||
) -> UpdateFunction:
|
) -> UpdateFunction:
|
||||||
|
@ -235,7 +250,7 @@ class BackfillStream(Stream):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
hs.get_instance_name(),
|
hs.get_instance_name(),
|
||||||
store.get_current_backfill_token,
|
current_token_without_instance(store.get_current_backfill_token),
|
||||||
db_query_to_update_function(store.get_all_new_backfill_event_rows),
|
db_query_to_update_function(store.get_all_new_backfill_event_rows),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -271,7 +286,9 @@ class PresenceStream(Stream):
|
||||||
update_function = make_http_update_function(hs, self.NAME)
|
update_function = make_http_update_function(hs, self.NAME)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
hs.get_instance_name(), store.get_current_presence_token, update_function
|
hs.get_instance_name(),
|
||||||
|
current_token_without_instance(store.get_current_presence_token),
|
||||||
|
update_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -296,7 +313,9 @@ class TypingStream(Stream):
|
||||||
update_function = make_http_update_function(hs, self.NAME)
|
update_function = make_http_update_function(hs, self.NAME)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
hs.get_instance_name(), typing_handler.get_current_token, update_function
|
hs.get_instance_name(),
|
||||||
|
current_token_without_instance(typing_handler.get_current_token),
|
||||||
|
update_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -319,7 +338,7 @@ class ReceiptsStream(Stream):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
hs.get_instance_name(),
|
hs.get_instance_name(),
|
||||||
store.get_max_receipt_stream_id,
|
current_token_without_instance(store.get_max_receipt_stream_id),
|
||||||
db_query_to_update_function(store.get_all_updated_receipts),
|
db_query_to_update_function(store.get_all_updated_receipts),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -339,7 +358,7 @@ class PushRulesStream(Stream):
|
||||||
hs.get_instance_name(), self._current_token, self._update_function
|
hs.get_instance_name(), self._current_token, self._update_function
|
||||||
)
|
)
|
||||||
|
|
||||||
def _current_token(self) -> int:
|
def _current_token(self, instance_name: str) -> int:
|
||||||
push_rules_token, _ = self.store.get_push_rules_stream_token()
|
push_rules_token, _ = self.store.get_push_rules_stream_token()
|
||||||
return push_rules_token
|
return push_rules_token
|
||||||
|
|
||||||
|
@ -373,7 +392,7 @@ class PushersStream(Stream):
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
hs.get_instance_name(),
|
hs.get_instance_name(),
|
||||||
store.get_pushers_stream_token,
|
current_token_without_instance(store.get_pushers_stream_token),
|
||||||
db_query_to_update_function(store.get_all_updated_pushers_rows),
|
db_query_to_update_function(store.get_all_updated_pushers_rows),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -402,13 +421,27 @@ class CachesStream(Stream):
|
||||||
ROW_TYPE = CachesStreamRow
|
ROW_TYPE = CachesStreamRow
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
hs.get_instance_name(),
|
hs.get_instance_name(),
|
||||||
store.get_cache_stream_token,
|
self.store.get_cache_stream_token,
|
||||||
db_query_to_update_function(store.get_all_updated_caches),
|
self._update_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _update_function(
|
||||||
|
self, instance_name: str, from_token: int, upto_token: int, limit: int
|
||||||
|
):
|
||||||
|
rows = await self.store.get_all_updated_caches(
|
||||||
|
instance_name, from_token, upto_token, limit
|
||||||
|
)
|
||||||
|
updates = [(row[0], row[1:]) for row in rows]
|
||||||
|
limited = False
|
||||||
|
if len(updates) >= limit:
|
||||||
|
upto_token = updates[-1][0]
|
||||||
|
limited = True
|
||||||
|
|
||||||
|
return updates, upto_token, limited
|
||||||
|
|
||||||
|
|
||||||
class PublicRoomsStream(Stream):
|
class PublicRoomsStream(Stream):
|
||||||
"""The public rooms list changed
|
"""The public rooms list changed
|
||||||
|
@ -431,7 +464,7 @@ class PublicRoomsStream(Stream):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
hs.get_instance_name(),
|
hs.get_instance_name(),
|
||||||
store.get_current_public_room_stream_id,
|
current_token_without_instance(store.get_current_public_room_stream_id),
|
||||||
db_query_to_update_function(store.get_all_new_public_rooms),
|
db_query_to_update_function(store.get_all_new_public_rooms),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -452,7 +485,7 @@ class DeviceListsStream(Stream):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
hs.get_instance_name(),
|
hs.get_instance_name(),
|
||||||
store.get_device_stream_token,
|
current_token_without_instance(store.get_device_stream_token),
|
||||||
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
|
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -470,7 +503,7 @@ class ToDeviceStream(Stream):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
hs.get_instance_name(),
|
hs.get_instance_name(),
|
||||||
store.get_to_device_stream_token,
|
current_token_without_instance(store.get_to_device_stream_token),
|
||||||
db_query_to_update_function(store.get_all_new_device_messages),
|
db_query_to_update_function(store.get_all_new_device_messages),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -490,7 +523,7 @@ class TagAccountDataStream(Stream):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
hs.get_instance_name(),
|
hs.get_instance_name(),
|
||||||
store.get_max_account_data_stream_id,
|
current_token_without_instance(store.get_max_account_data_stream_id),
|
||||||
db_query_to_update_function(store.get_all_updated_tags),
|
db_query_to_update_function(store.get_all_updated_tags),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -510,7 +543,7 @@ class AccountDataStream(Stream):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
hs.get_instance_name(),
|
hs.get_instance_name(),
|
||||||
self.store.get_max_account_data_stream_id,
|
current_token_without_instance(self.store.get_max_account_data_stream_id),
|
||||||
db_query_to_update_function(self._update_function),
|
db_query_to_update_function(self._update_function),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -541,7 +574,7 @@ class GroupServerStream(Stream):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
hs.get_instance_name(),
|
hs.get_instance_name(),
|
||||||
store.get_group_stream_token,
|
current_token_without_instance(store.get_group_stream_token),
|
||||||
db_query_to_update_function(store.get_all_groups_changes),
|
db_query_to_update_function(store.get_all_groups_changes),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -559,7 +592,7 @@ class UserSignatureStream(Stream):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
hs.get_instance_name(),
|
hs.get_instance_name(),
|
||||||
store.get_device_stream_token,
|
current_token_without_instance(store.get_device_stream_token),
|
||||||
db_query_to_update_function(
|
db_query_to_update_function(
|
||||||
store.get_all_user_signature_changes_for_remotes
|
store.get_all_user_signature_changes_for_remotes
|
||||||
),
|
),
|
||||||
|
|
|
@ -20,7 +20,7 @@ from typing import List, Tuple, Type
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from ._base import Stream, StreamUpdateResult, Token
|
from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
|
||||||
|
|
||||||
|
|
||||||
"""Handling of the 'events' replication stream
|
"""Handling of the 'events' replication stream
|
||||||
|
@ -119,7 +119,7 @@ class EventsStream(Stream):
|
||||||
self._store = hs.get_datastore()
|
self._store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
hs.get_instance_name(),
|
hs.get_instance_name(),
|
||||||
self._store.get_current_events_token,
|
current_token_without_instance(self._store.get_current_events_token),
|
||||||
self._update_function,
|
self._update_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,11 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
from synapse.replication.tcp.streams._base import Stream, make_http_update_function
|
from synapse.replication.tcp.streams._base import (
|
||||||
|
Stream,
|
||||||
|
current_token_without_instance,
|
||||||
|
make_http_update_function,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FederationStream(Stream):
|
class FederationStream(Stream):
|
||||||
|
@ -41,7 +45,9 @@ class FederationStream(Stream):
|
||||||
# will be a real FederationSender, which has stubs for current_token and
|
# will be a real FederationSender, which has stubs for current_token and
|
||||||
# get_replication_rows.)
|
# get_replication_rows.)
|
||||||
federation_sender = hs.get_federation_sender()
|
federation_sender = hs.get_federation_sender()
|
||||||
current_token = federation_sender.get_current_token
|
current_token = current_token_without_instance(
|
||||||
|
federation_sender.get_current_token
|
||||||
|
)
|
||||||
update_function = federation_sender.get_replication_rows
|
update_function = federation_sender.get_replication_rows
|
||||||
|
|
||||||
elif hs.should_send_federation():
|
elif hs.should_send_federation():
|
||||||
|
@ -58,7 +64,7 @@ class FederationStream(Stream):
|
||||||
super().__init__(hs.get_instance_name(), current_token, update_function)
|
super().__init__(hs.get_instance_name(), current_token, update_function)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _stub_current_token():
|
def _stub_current_token(instance_name: str) -> int:
|
||||||
# dummy current-token method for use on workers
|
# dummy current-token method for use on workers
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
|
@ -47,6 +47,9 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||||
self.db = database
|
self.db = database
|
||||||
self.rand = random.SystemRandom()
|
self.rand = random.SystemRandom()
|
||||||
|
|
||||||
|
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||||
|
pass
|
||||||
|
|
||||||
def _invalidate_state_caches(self, room_id, members_changed):
|
def _invalidate_state_caches(self, room_id, members_changed):
|
||||||
"""Invalidates caches that are based on the current state, but does
|
"""Invalidates caches that are based on the current state, but does
|
||||||
not stream invalidations down replication.
|
not stream invalidations down replication.
|
||||||
|
|
|
@ -26,13 +26,14 @@ from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.util.id_generators import (
|
from synapse.storage.util.id_generators import (
|
||||||
ChainedIdGenerator,
|
ChainedIdGenerator,
|
||||||
IdGenerator,
|
IdGenerator,
|
||||||
|
MultiWriterIdGenerator,
|
||||||
StreamIdGenerator,
|
StreamIdGenerator,
|
||||||
)
|
)
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
|
||||||
from .account_data import AccountDataStore
|
from .account_data import AccountDataStore
|
||||||
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
|
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
|
||||||
from .cache import CacheInvalidationStore
|
from .cache import CacheInvalidationWorkerStore
|
||||||
from .client_ips import ClientIpStore
|
from .client_ips import ClientIpStore
|
||||||
from .deviceinbox import DeviceInboxStore
|
from .deviceinbox import DeviceInboxStore
|
||||||
from .devices import DeviceStore
|
from .devices import DeviceStore
|
||||||
|
@ -112,8 +113,8 @@ class DataStore(
|
||||||
MonthlyActiveUsersStore,
|
MonthlyActiveUsersStore,
|
||||||
StatsStore,
|
StatsStore,
|
||||||
RelationsStore,
|
RelationsStore,
|
||||||
CacheInvalidationStore,
|
|
||||||
UIAuthStore,
|
UIAuthStore,
|
||||||
|
CacheInvalidationWorkerStore,
|
||||||
):
|
):
|
||||||
def __init__(self, database: Database, db_conn, hs):
|
def __init__(self, database: Database, db_conn, hs):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
@ -170,8 +171,14 @@ class DataStore(
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
self._cache_id_gen = StreamIdGenerator(
|
self._cache_id_gen = MultiWriterIdGenerator(
|
||||||
db_conn, "cache_invalidation_stream", "stream_id"
|
db_conn,
|
||||||
|
database,
|
||||||
|
instance_name="master",
|
||||||
|
table="cache_invalidation_stream_by_instance",
|
||||||
|
instance_column="instance_name",
|
||||||
|
id_column="stream_id",
|
||||||
|
sequence_name="cache_invalidation_stream_seq",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._cache_id_gen = None
|
self._cache_id_gen = None
|
||||||
|
|
|
@ -16,11 +16,10 @@
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Iterable, Optional, Tuple
|
from typing import Any, Iterable, Optional
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
from synapse.storage.database import Database
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
|
|
||||||
|
@ -33,47 +32,58 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
|
||||||
|
|
||||||
|
|
||||||
class CacheInvalidationWorkerStore(SQLBaseStore):
|
class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||||
def get_all_updated_caches(self, last_id, current_id, limit):
|
def __init__(self, database: Database, db_conn, hs):
|
||||||
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
|
self._instance_name = hs.get_instance_name()
|
||||||
|
|
||||||
|
async def get_all_updated_caches(
|
||||||
|
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||||
|
):
|
||||||
|
"""Fetches cache invalidation rows between the two given IDs written
|
||||||
|
by the given instance. Returns at most `limit` rows.
|
||||||
|
"""
|
||||||
|
|
||||||
if last_id == current_id:
|
if last_id == current_id:
|
||||||
return defer.succeed([])
|
return []
|
||||||
|
|
||||||
def get_all_updated_caches_txn(txn):
|
def get_all_updated_caches_txn(txn):
|
||||||
# We purposefully don't bound by the current token, as we want to
|
# We purposefully don't bound by the current token, as we want to
|
||||||
# send across cache invalidations as quickly as possible. Cache
|
# send across cache invalidations as quickly as possible. Cache
|
||||||
# invalidations are idempotent, so duplicates are fine.
|
# invalidations are idempotent, so duplicates are fine.
|
||||||
sql = (
|
sql = """
|
||||||
"SELECT stream_id, cache_func, keys, invalidation_ts"
|
SELECT stream_id, cache_func, keys, invalidation_ts
|
||||||
" FROM cache_invalidation_stream"
|
FROM cache_invalidation_stream_by_instance
|
||||||
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
|
WHERE stream_id > ? AND instance_name = ?
|
||||||
)
|
ORDER BY stream_id ASC
|
||||||
txn.execute(sql, (last_id, limit))
|
LIMIT ?
|
||||||
|
"""
|
||||||
|
txn.execute(sql, (last_id, instance_name, limit))
|
||||||
return txn.fetchall()
|
return txn.fetchall()
|
||||||
|
|
||||||
return self.db.runInteraction(
|
return await self.db.runInteraction(
|
||||||
"get_all_updated_caches", get_all_updated_caches_txn
|
"get_all_updated_caches", get_all_updated_caches_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||||
|
if stream_name == "caches":
|
||||||
|
if self._cache_id_gen:
|
||||||
|
self._cache_id_gen.advance(instance_name, token)
|
||||||
|
|
||||||
class CacheInvalidationStore(CacheInvalidationWorkerStore):
|
for row in rows:
|
||||||
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
|
if row.cache_func == CURRENT_STATE_CACHE_NAME:
|
||||||
"""Invalidates the cache and adds it to the cache stream so slaves
|
if row.keys is None:
|
||||||
will know to invalidate their caches.
|
raise Exception(
|
||||||
|
"Can't send an 'invalidate all' for current state cache"
|
||||||
|
)
|
||||||
|
|
||||||
This should only be used to invalidate caches where slaves won't
|
room_id = row.keys[0]
|
||||||
otherwise know from other replication streams that the cache should
|
members_changed = set(row.keys[1:])
|
||||||
be invalidated.
|
self._invalidate_state_caches(room_id, members_changed)
|
||||||
"""
|
else:
|
||||||
cache_func = getattr(self, cache_name, None)
|
self._attempt_to_invalidate_cache(row.cache_func, row.keys)
|
||||||
if not cache_func:
|
|
||||||
return
|
|
||||||
|
|
||||||
cache_func.invalidate(keys)
|
super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
await self.runInteraction(
|
|
||||||
"invalidate_cache_and_stream",
|
|
||||||
self._send_invalidation_to_replication,
|
|
||||||
cache_func.__name__,
|
|
||||||
keys,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
|
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
|
||||||
"""Invalidates the cache and adds it to the cache stream so slaves
|
"""Invalidates the cache and adds it to the cache stream so slaves
|
||||||
|
@ -147,10 +157,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
|
||||||
# the transaction. However, we want to only get an ID when we want
|
# the transaction. However, we want to only get an ID when we want
|
||||||
# to use it, here, so we need to call __enter__ manually, and have
|
# to use it, here, so we need to call __enter__ manually, and have
|
||||||
# __exit__ called after the transaction finishes.
|
# __exit__ called after the transaction finishes.
|
||||||
ctx = self._cache_id_gen.get_next()
|
stream_id = self._cache_id_gen.get_next_txn(txn)
|
||||||
stream_id = ctx.__enter__()
|
|
||||||
txn.call_on_exception(ctx.__exit__, None, None, None)
|
|
||||||
txn.call_after(ctx.__exit__, None, None, None)
|
|
||||||
txn.call_after(self.hs.get_notifier().on_new_replication_data)
|
txn.call_after(self.hs.get_notifier().on_new_replication_data)
|
||||||
|
|
||||||
if keys is not None:
|
if keys is not None:
|
||||||
|
@ -158,17 +165,18 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
|
||||||
|
|
||||||
self.db.simple_insert_txn(
|
self.db.simple_insert_txn(
|
||||||
txn,
|
txn,
|
||||||
table="cache_invalidation_stream",
|
table="cache_invalidation_stream_by_instance",
|
||||||
values={
|
values={
|
||||||
"stream_id": stream_id,
|
"stream_id": stream_id,
|
||||||
|
"instance_name": self._instance_name,
|
||||||
"cache_func": cache_name,
|
"cache_func": cache_name,
|
||||||
"keys": keys,
|
"keys": keys,
|
||||||
"invalidation_ts": self.clock.time_msec(),
|
"invalidation_ts": self.clock.time_msec(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_cache_stream_token(self):
|
def get_cache_stream_token(self, instance_name):
|
||||||
if self._cache_id_gen:
|
if self._cache_id_gen:
|
||||||
return self._cache_id_gen.get_current_token()
|
return self._cache_id_gen.get_current_token(instance_name)
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- We keep the old table here to enable us to roll back. It doesn't matter
|
||||||
|
-- that we have dropped all the data here.
|
||||||
|
TRUNCATE cache_invalidation_stream;
|
||||||
|
|
||||||
|
CREATE TABLE cache_invalidation_stream_by_instance (
|
||||||
|
stream_id BIGINT NOT NULL,
|
||||||
|
instance_name TEXT NOT NULL,
|
||||||
|
cache_func TEXT NOT NULL,
|
||||||
|
keys TEXT[],
|
||||||
|
invalidation_ts BIGINT
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance(stream_id);
|
||||||
|
|
||||||
|
CREATE SEQUENCE cache_invalidation_stream_seq;
|
|
@ -29,6 +29,8 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Remember to update this number every time a change is made to database
|
# Remember to update this number every time a change is made to database
|
||||||
# schema files, so the users will be informed on server restarts.
|
# schema files, so the users will be informed on server restarts.
|
||||||
|
# XXX: If you're about to bump this to 59 (or higher) please create an update
|
||||||
|
# that drops the unused `cache_invalidation_stream` table, as per #7436!
|
||||||
SCHEMA_VERSION = 58
|
SCHEMA_VERSION = 58
|
||||||
|
|
||||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
Loading…
Reference in a new issue