Mirror of https://github.com/element-hq/synapse.git
Move catchup of replication streams to worker. (#7024)
This changes the replication protocol so that the server does not send down `RDATA` for rows that happened before the client connected. Instead, the server will send a `POSITION` and clients then query the database (or master out of band) to get up to date.
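For a concrete sense of the change, here is an illustrative wire transcript (not part of the commit) contrasting the two handshakes, in the `>` master-to-worker / `<` worker-to-master notation used by `docs/tcp_replication.md`:

    Before: the worker subscribed per stream, quoting a token, and the
    server itself replayed any missed rows:

        > SERVER example.com
        < REPLICATE events 53
        > RDATA events 54 ["$foo1:bar.com", ...]

    After: a single bare REPLICATE; the server answers with the current
    position of every stream, and the worker fetches any missed range
    itself (from the database, or from the master over the replication
    HTTP API added below):

        > SERVER example.com
        < REPLICATE
        > POSITION events 53
        > RDATA events 54 ["$foo1:bar.com", ...]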
parent 7bab642707
commit 4cff617df1

24 changed files with 635 additions and 487 deletions
changelog.d/7024.misc (new file, 1 line)

@@ -0,0 +1 @@
+Move catchup of replication streams logic to worker.
docs/tcp_replication.md

@@ -14,16 +14,16 @@ example flow would be (where '>' indicates master to worker and
 '<' worker to master flows):

     > SERVER example.com
-    < REPLICATE events 53
+    < REPLICATE
+    > POSITION events 53
     > RDATA events 54 ["$foo1:bar.com", ...]
     > RDATA events 55 ["$foo4:bar.com", ...]

-The example shows the server accepting a new connection and sending its
-identity with the `SERVER` command, followed by the client asking to
-subscribe to the `events` stream from the token `53`. The server then
-periodically sends `RDATA` commands which have the format
-`RDATA <stream_name> <token> <row>`, where the format of `<row>` is
-defined by the individual streams.
+The example shows the server accepting a new connection and sending its identity
+with the `SERVER` command, followed by the client asking the server to respond
+with the position of all streams. The server then periodically sends `RDATA`
+commands which have the format `RDATA <stream_name> <token> <row>`, where the
+format of `<row>` is defined by the individual streams.

 Error reporting happens by either the client or server sending an ERROR
 command, and usually the connection will be closed.
@@ -32,9 +32,6 @@ Since the protocol is a simple line based, its possible to manually
 connect to the server using a tool like netcat. A few things should be
 noted when manually using the protocol:

-- When subscribing to a stream using `REPLICATE`, the special token
-  `NOW` can be used to get all future updates. The special stream name
-  `ALL` can be used with `NOW` to subscribe to all available streams.
 - The federation stream is only available if federation sending has
   been disabled on the main process.
 - The server will only time connections out that have sent a `PING`
@@ -91,9 +88,7 @@ The client:
 - Sends a `NAME` command, allowing the server to associate a human
   friendly name with the connection. This is optional.
 - Sends a `PING` as above
-- For each stream the client wishes to subscribe to it sends a
-  `REPLICATE` with the `stream_name` and token it wants to subscribe
-  from.
+- Sends a `REPLICATE` to get the current position of all streams.
 - On receipt of a `SERVER` command, checks that the server name
   matches the expected server name.
@@ -140,9 +135,7 @@ the wire:
     > PING 1490197665618
     < NAME synapse.app.appservice
     < PING 1490197665618
-    < REPLICATE events 1
-    < REPLICATE backfill 1
-    < REPLICATE caches 1
+    < REPLICATE
     > POSITION events 1
     > POSITION backfill 1
     > POSITION caches 1
@@ -181,9 +174,9 @@ client (C):

 #### POSITION (S)

-   The position of the stream has been updated. Sent to the client
-   after all missing updates for a stream have been sent to the client
-   and they're now up to date.
+   On receipt of a POSITION command clients should check if they have missed any
+   updates, and if so then fetch them out of band. Sent in response to a
+   REPLICATE command (but can happen at any time).

 #### ERROR (S, C)
@@ -199,20 +192,7 @@ client (C):

 #### REPLICATE (C)

-   Asks the server to replicate a given stream. The syntax is:
-
-   ```
-       REPLICATE <stream_name> <token>
-   ```
-
-   Where `<token>` may be either:
-    * a numeric stream_id to stream updates since (exclusive)
-    * `NOW` to stream all subsequent updates.
-
-   The `<stream_name>` is the name of a replication stream to subscribe
-   to (see [here](../synapse/replication/tcp/streams/_base.py) for a list
-   of streams). It can also be `ALL` to subscribe to all known streams,
-   in which case the `<token>` must be set to `NOW`.
+   Asks the server for the current position of all streams.

 #### USER_SYNC (C)
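The "query the master out of band" path mentioned above is the new internal replication endpoint added later in this commit (`synapse/replication/http/streams.py`). An illustrative exchange, with invented token values; the query parameters follow `_serialize_payload` in that file:

    GET /_synapse/replication/get_repl_stream_updates/typing?from_token=53&upto_token=60&limit=100

    200 OK
    {"updates": [...], "upto_token": 60, "limited": false}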
synapse/app/generic_worker.py

@@ -401,6 +401,9 @@ class GenericWorkerTyping(object):
             self._room_serials[row.room_id] = token
             self._room_typing[row.room_id] = row.user_ids

+    def get_current_token(self) -> int:
+        return self._latest_room_serial
+

 class GenericWorkerSlavedStore(
     # FIXME(#3714): We need to add UserDirectoryStore as we write directly
synapse/federation/sender/__init__.py

@@ -499,4 +499,13 @@ class FederationSender(object):
         self._get_per_destination_queue(destination).attempt_new_transaction()

     def get_current_token(self) -> int:
+        # Dummy implementation for case where federation sender isn't offloaded
+        # to a worker.
         return 0
+
+    async def get_replication_rows(
+        self, from_token, to_token, limit, federation_ack=None
+    ):
+        # Dummy implementation for case where federation sender isn't offloaded
+        # to a worker.
+        return []
synapse/replication/http/__init__.py

@@ -21,6 +21,7 @@ from synapse.replication.http import (
     membership,
     register,
     send_event,
+    streams,
 )

 REPLICATION_PREFIX = "/_synapse/replication"
@@ -38,3 +39,4 @@ class ReplicationRestResource(JsonResource):
         login.register_servlets(hs, self)
         register.register_servlets(hs, self)
         devices.register_servlets(hs, self)
+        streams.register_servlets(hs, self)
synapse/replication/http/streams.py (new file, 78 lines)
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import parse_integer
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationGetStreamUpdates(ReplicationEndpoint):
+    """Fetches stream updates from a server. Used for streams not persisted to
+    the database, e.g. typing notifications.
+
+    The API looks like:
+
+        GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&to_token=10&limit=100
+
+        200 OK
+
+        {
+            updates: [ ... ],
+            upto_token: 10,
+            limited: False,
+        }
+    """
+
+    NAME = "get_repl_stream_updates"
+    PATH_ARGS = ("stream_name",)
+    METHOD = "GET"
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        # We pull the streams from the replication streamer (if we try and make
+        # them ourselves we end up in an import loop).
+        self.streams = hs.get_replication_streamer().get_streams()
+
+    @staticmethod
+    def _serialize_payload(stream_name, from_token, upto_token, limit):
+        return {"from_token": from_token, "upto_token": upto_token, "limit": limit}
+
+    async def _handle_request(self, request, stream_name):
+        stream = self.streams.get(stream_name)
+        if stream is None:
+            raise SynapseError(400, "Unknown stream")
+
+        from_token = parse_integer(request, "from_token", required=True)
+        upto_token = parse_integer(request, "upto_token", required=True)
+        limit = parse_integer(request, "limit", required=True)
+
+        updates, upto_token, limited = await stream.get_updates_since(
+            from_token, upto_token, limit
+        )
+
+        return (
+            200,
+            {"updates": updates, "upto_token": upto_token, "limited": limited},
+        )
+
+
+def register_servlets(hs, http_server):
+    ReplicationGetStreamUpdates(hs).register(http_server)
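For worker-side callers, `ReplicationEndpoint.make_client` builds the client half of this endpoint. A minimal sketch of the intended usage (this mirrors `make_http_update_function` in `synapse/replication/tcp/streams/_base.py` below; `hs` is the worker's `HomeServer`):

```python
# Sketch: querying the master for stream updates from a worker.
client = ReplicationGetStreamUpdates.make_client(hs)

async def fetch_updates(from_token: int, upto_token: int, limit: int):
    # Resolves to {"updates": [...], "upto_token": ..., "limited": ...},
    # mirroring the Stream.get_updates_since contract.
    return await client(
        stream_name="typing",
        from_token=from_token,
        upto_token=upto_token,
        limit=limit,
    )
```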
synapse/replication/slave/storage/_base.py

@@ -18,8 +18,10 @@ from typing import Dict, Optional

 import six

-from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
+from synapse.storage.data_stores.main.cache import (
+    CURRENT_STATE_CACHE_NAME,
+    CacheInvalidationWorkerStore,
+)
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine

@@ -35,7 +37,7 @@ def __func__(inp):
     return inp.__func__


-class BaseSlavedStore(SQLBaseStore):
+class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: Database, db_conn, hs):
         super(BaseSlavedStore, self).__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):

@@ -60,6 +62,12 @@ class BaseSlavedStore(SQLBaseStore):
             pos["caches"] = self._cache_id_gen.get_current_token()
         return pos

+    def get_cache_stream_token(self):
+        if self._cache_id_gen:
+            return self._cache_id_gen.get_current_token()
+        else:
+            return 0
+
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "caches":
             if self._cache_id_gen:
synapse/replication/slave/storage/pushers.py

@@ -33,6 +33,9 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
         result["pushers"] = self._pushers_id_gen.get_current_token()
         return result

+    def get_pushers_stream_token(self):
+        return self._pushers_id_gen.get_current_token()
+
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "pushers":
             self._pushers_id_gen.advance(token)
synapse/replication/tcp/client.py

@@ -55,6 +55,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
         self.client_name = client_name
         self.handler = handler
         self.server_name = hs.config.server_name
+        self.hs = hs
         self._clock = hs.get_clock()  # As self.clock is defined in super class

         hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)

@@ -65,7 +66,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
     def buildProtocol(self, addr):
         logger.info("Connected to replication: %r", addr)
         return ClientReplicationStreamProtocol(
-            self.client_name, self.server_name, self._clock, self.handler
+            self.hs, self.client_name, self.server_name, self._clock, self.handler,
         )

     def clientConnectionLost(self, connector, reason):
synapse/replication/tcp/commands.py

@@ -136,8 +136,8 @@ class PositionCommand(Command):
     """Sent by the server to tell the client the stream postition without
     needing to send an RDATA.

-    Sent to the client after all missing updates for a stream have been sent
-    to the client and they're now up to date.
+    On receipt of a POSITION command clients should check if they have missed
+    any updates, and if so then fetch them out of band.
     """

     NAME = "POSITION"

@@ -179,42 +179,24 @@ class NameCommand(Command):


 class ReplicateCommand(Command):
-    """Sent by the client to subscribe to the stream.
+    """Sent by the client to subscribe to streams.

     Format::

-        REPLICATE <stream_name> <token>
-
-    Where <token> may be either:
-        * a numeric stream_id to stream updates from
-        * "NOW" to stream all subsequent updates.
-
-    The <stream_name> can be "ALL" to subscribe to all known streams, in which
-    case the <token> must be set to "NOW", i.e.::
-
-        REPLICATE ALL NOW
+        REPLICATE
     """

     NAME = "REPLICATE"

-    def __init__(self, stream_name, token):
-        self.stream_name = stream_name
-        self.token = token
+    def __init__(self):
+        pass

     @classmethod
     def from_line(cls, line):
-        stream_name, token = line.split(" ", 1)
-        if token in ("NOW", "now"):
-            token = "NOW"
-        else:
-            token = int(token)
-        return cls(stream_name, token)
+        return cls()

     def to_line(self):
-        return " ".join((self.stream_name, str(self.token)))
-
-    def get_logcontext_id(self):
-        return "REPLICATE-" + self.stream_name
+        return ""


 class UserSyncCommand(Command):
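Worth noting the wire-level effect of the simplified command: with no arguments, `to_line()` yields an empty payload, so the client now sends the bare `REPLICATE` keyword. A quick illustrative round-trip using the class as defined above:

```python
# Sketch: the REPLICATE command no longer carries a stream name or token.
cmd = ReplicateCommand()
assert cmd.to_line() == ""  # serialised wire line is just "REPLICATE"
assert isinstance(ReplicateCommand.from_line(""), ReplicateCommand)
```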
synapse/replication/tcp/protocol.py

@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
    > PING 1490197665618
    < NAME synapse.app.appservice
    < PING 1490197665618
-   < REPLICATE events 1
-   < REPLICATE backfill 1
-   < REPLICATE caches 1
+   < REPLICATE
    > POSITION events 1
    > POSITION backfill 1
    > POSITION caches 1

@@ -53,17 +51,15 @@ import fcntl
 import logging
 import struct
 from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Set, Tuple
+from typing import Any, DefaultDict, Dict, List, Set

-from six import iteritems, iterkeys
+from six import iteritems

 from prometheus_client import Counter

-from twisted.internet import defer
 from twisted.protocols.basic import LineOnlyReceiver
 from twisted.python.failure import Failure

-from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.commands import (

@@ -82,11 +78,16 @@ from synapse.replication.tcp.commands import (
     SyncCommand,
     UserSyncCommand,
 )
-from synapse.replication.tcp.streams import STREAMS_MAP
+from synapse.replication.tcp.streams import STREAMS_MAP, Stream
 from synapse.types import Collection
 from synapse.util import Clock
 from synapse.util.stringutils import random_string

+MYPY = False
+if MYPY:
+    from synapse.server import HomeServer
+
+
 connection_close_counter = Counter(
     "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
 )
@@ -411,16 +412,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.server_name = server_name
         self.streamer = streamer

-        # The streams the client has subscribed to and is up to date with
-        self.replication_streams = set()  # type: Set[str]
-
-        # The streams the client is currently subscribing to.
-        self.connecting_streams = set()  # type: Set[str]
-
-        # Map from stream name to list of updates to send once we've finished
-        # subscribing the client to the stream.
-        self.pending_rdata = {}  # type: Dict[str, List[Tuple[int, Any]]]
-
     def connectionMade(self):
         self.send_command(ServerCommand(self.server_name))
         BaseReplicationStreamProtocol.connectionMade(self)

@@ -436,21 +427,10 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         )

     async def on_REPLICATE(self, cmd):
-        stream_name = cmd.stream_name
-        token = cmd.token
-
-        if stream_name == "ALL":
-            # Subscribe to all streams we're publishing to.
-            deferreds = [
-                run_in_background(self.subscribe_to_stream, stream, token)
-                for stream in iterkeys(self.streamer.streams_by_name)
-            ]
-
-            await make_deferred_yieldable(
-                defer.gatherResults(deferreds, consumeErrors=True)
-            )
-        else:
-            await self.subscribe_to_stream(stream_name, token)
+        # Subscribe to all streams we're publishing to.
+        for stream_name in self.streamer.streams_by_name:
+            current_token = self.streamer.get_stream_token(stream_name)
+            self.send_command(PositionCommand(stream_name, current_token))

     async def on_FEDERATION_ACK(self, cmd):
         self.streamer.federation_ack(cmd.token)
@@ -474,87 +454,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
             cmd.last_seen,
         )

-    async def subscribe_to_stream(self, stream_name, token):
-        """Subscribe the remote to a stream.
-
-        This invloves checking if they've missed anything and sending those
-        updates down if they have. During that time new updates for the stream
-        are queued and sent once we've sent down any missed updates.
-        """
-        self.replication_streams.discard(stream_name)
-        self.connecting_streams.add(stream_name)
-
-        try:
-            # Get missing updates
-            updates, current_token = await self.streamer.get_stream_updates(
-                stream_name, token
-            )
-
-            # Send all the missing updates
-            for update in updates:
-                token, row = update[0], update[1]
-                self.send_command(RdataCommand(stream_name, token, row))
-
-            # We send a POSITION command to ensure that they have an up to
-            # date token (especially useful if we didn't send any updates
-            # above)
-            self.send_command(PositionCommand(stream_name, current_token))
-
-            # Now we can send any updates that came in while we were subscribing
-            pending_rdata = self.pending_rdata.pop(stream_name, [])
-            updates = []
-            for token, update in pending_rdata:
-                # If the token is null, it is part of a batch update. Batches
-                # are multiple updates that share a single token. To denote
-                # this, the token is set to None for all tokens in the batch
-                # except for the last. If we find a None token, we keep looking
-                # through tokens until we find one that is not None and then
-                # process all previous updates in the batch as if they had the
-                # final token.
-                if token is None:
-                    # Store this update as part of a batch
-                    updates.append(update)
-                    continue
-
-                if token <= current_token:
-                    # This update or batch of updates is older than
-                    # current_token, dismiss it
-                    updates = []
-                    continue
-
-                updates.append(update)
-
-                # Send all updates that are part of this batch with the
-                # found token
-                for update in updates:
-                    self.send_command(RdataCommand(stream_name, token, update))
-
-                # Clear stored updates
-                updates = []
-
-            # They're now fully subscribed
-            self.replication_streams.add(stream_name)
-        except Exception as e:
-            logger.exception("[%s] Failed to handle REPLICATE command", self.id())
-            self.send_error("failed to handle replicate: %r", e)
-        finally:
-            self.connecting_streams.discard(stream_name)
-
     def stream_update(self, stream_name, token, data):
         """Called when a new update is available to stream to clients.

         We need to check if the client is interested in the stream or not
         """
-        if stream_name in self.replication_streams:
-            # The client is subscribed to the stream
-            self.send_command(RdataCommand(stream_name, token, data))
-        elif stream_name in self.connecting_streams:
-            # The client is being subscribed to the stream
-            logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
-            self.pending_rdata.setdefault(stream_name, []).append((token, data))
-        else:
-            # The client isn't subscribed
-            logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
+        self.send_command(RdataCommand(stream_name, token, data))

     def send_sync(self, data):
         self.send_command(SyncCommand(data))
@@ -638,6 +543,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):

     def __init__(
         self,
+        hs: "HomeServer",
         client_name: str,
         server_name: str,
         clock: Clock,

@@ -649,22 +555,25 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.server_name = server_name
         self.handler = handler

+        self.streams = {
+            stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
+        }  # type: Dict[str, Stream]
+
         # Set of stream names that have been subscribe to, but haven't yet
         # caught up with. This is used to track when the client has been fully
         # connected to the remote.
-        self.streams_connecting = set()  # type: Set[str]
+        self.streams_connecting = set(STREAMS_MAP)  # type: Set[str]

         # Map of stream to batched updates. See RdataCommand for info on how
         # batching works.
-        self.pending_batches = {}  # type: Dict[str, Any]
+        self.pending_batches = {}  # type: Dict[str, List[Any]]

     def connectionMade(self):
         self.send_command(NameCommand(self.client_name))
         BaseReplicationStreamProtocol.connectionMade(self)

         # Once we've connected subscribe to the necessary streams
-        for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
-            self.replicate(stream_name, token)
+        self.replicate()

         # Tell the server if we have any users currently syncing (should only
         # happen on synchrotrons)

@@ -676,10 +585,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         # We've now finished connecting to so inform the client handler
         self.handler.update_connection(self)

-        # This will happen if we don't actually subscribe to any streams
-        if not self.streams_connecting:
-            self.handler.finished_connecting()
-
     async def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)

@@ -697,7 +602,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             )
             raise

-        if cmd.token is None:
+        if cmd.token is None or stream_name in self.streams_connecting:
             # I.e. this is part of a batch of updates for this stream. Batch
             # until we get an update for the stream with a non None token
             self.pending_batches.setdefault(stream_name, []).append(row)
@@ -707,14 +612,55 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             rows.append(row)
         await self.handler.on_rdata(stream_name, cmd.token, rows)

-    async def on_POSITION(self, cmd):
-        # When we get a `POSITION` command it means we've finished getting
-        # missing updates for the given stream, and are now up to date.
+    async def on_POSITION(self, cmd: PositionCommand):
+        stream = self.streams.get(cmd.stream_name)
+        if not stream:
+            logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
+            return
+
+        # Find where we previously streamed up to.
+        current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
+        if current_token is None:
+            logger.warning(
+                "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
+            )
+            return
+
+        # Fetch all updates between then and now.
+        limited = True
+        while limited:
+            updates, current_token, limited = await stream.get_updates_since(
+                current_token, cmd.token
+            )
+
+            # Check if the connection was closed underneath us, if so we bail
+            # rather than risk having concurrent catch ups going on.
+            if self.state == ConnectionStates.CLOSED:
+                return
+
+            if updates:
+                await self.handler.on_rdata(
+                    cmd.stream_name,
+                    current_token,
+                    [stream.parse_row(update[1]) for update in updates],
+                )
+
+        # We've now caught up to position sent to us, notify handler.
+        await self.handler.on_position(cmd.stream_name, cmd.token)
+
         self.streams_connecting.discard(cmd.stream_name)
         if not self.streams_connecting:
             self.handler.finished_connecting()

-        await self.handler.on_position(cmd.stream_name, cmd.token)
+        # Check if the connection was closed underneath us, if so we bail
+        # rather than risk having concurrent catch ups going on.
+        if self.state == ConnectionStates.CLOSED:
+            return
+
+        # Handle any RDATA that came in while we were catching up.
+        rows = self.pending_batches.pop(cmd.stream_name, [])
+        if rows:
+            await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)

     async def on_SYNC(self, cmd):
         self.handler.on_sync(cmd.data)
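The loop above leans on the paging contract of `get_updates_since`: `limited` is `True` whenever the update function stopped at `limit` rows, in which case the returned token only reaches as far as it actually got. A toy `update_function` (not from the commit) demonstrating that contract:

```python
# Toy illustration of the (updates, upto_token, limited) contract.
async def update_function(from_token, upto_token, limit):
    rows = [(t, ("row-%d" % t,)) for t in range(from_token + 1, upto_token + 1)]
    if len(rows) > limit:
        rows = rows[:limit]
        # Stopped early: report the token we reached and that more remain.
        return rows, rows[-1][0], True
    return rows, upto_token, False
```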
@@ -722,22 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
     async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
         self.handler.on_remote_server_up(cmd.data)

-    def replicate(self, stream_name, token):
+    def replicate(self):
         """Send the subscription request to the server
         """
-        if stream_name not in STREAMS_MAP:
-            raise Exception("Invalid stream name %r" % (stream_name,))
-
-        logger.info(
-            "[%s] Subscribing to replication stream: %r from %r",
-            self.id(),
-            stream_name,
-            token,
-        )
-
-        self.streams_connecting.add(stream_name)
+        logger.info("[%s] Subscribing to replication streams", self.id())

-        self.send_command(ReplicateCommand(stream_name, token))
+        self.send_command(ReplicateCommand())

     def on_connection_closed(self):
         BaseReplicationStreamProtocol.on_connection_closed(self)
synapse/replication/tcp/resource.py

@@ -17,7 +17,7 @@

 import logging
 import random
-from typing import Any, List
+from typing import Any, Dict, List

 from six import itervalues

@@ -30,7 +30,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.metrics import Measure, measure_func

 from .protocol import ServerReplicationStreamProtocol
-from .streams import STREAMS_MAP
+from .streams import STREAMS_MAP, Stream
 from .streams.federation import FederationStream

 stream_updates_counter = Counter(
@@ -52,7 +52,7 @@ class ReplicationStreamProtocolFactory(Factory):
     """

     def __init__(self, hs):
-        self.streamer = ReplicationStreamer(hs)
+        self.streamer = hs.get_replication_streamer()
         self.clock = hs.get_clock()
         self.server_name = hs.config.server_name

@@ -133,6 +133,11 @@ class ReplicationStreamer(object):
             for conn in self.connections:
                 conn.send_error("server shutting down")

+    def get_streams(self) -> Dict[str, Stream]:
+        """Get a map from stream name to stream instance.
+        """
+        return self.streams_by_name
+
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
         connections if there are.

@@ -190,7 +195,8 @@ class ReplicationStreamer(object):
                             stream.current_token(),
                         )
                         try:
-                            updates, current_token = await stream.get_updates()
+                            updates, current_token, limited = await stream.get_updates()
+                            self.pending_updates |= limited
                         except Exception:
                             logger.info("Failed to handle stream %s", stream.NAME)
                             raise

@@ -226,8 +232,7 @@ class ReplicationStreamer(object):
         self.pending_updates = False
         self.is_looping = False

-    @measure_func("repl.get_stream_updates")
-    async def get_stream_updates(self, stream_name, token):
+    def get_stream_token(self, stream_name):
         """For a given stream get all updates since token. This is called when
         a client first subscribes to a stream.
         """

@@ -235,7 +240,7 @@ class ReplicationStreamer(object):
         if not stream:
             raise Exception("unknown stream %s", stream_name)

-        return await stream.get_updates_since(token)
+        return stream.current_token()

     @measure_func("repl.federation_ack")
     def federation_ack(self, token):
synapse/replication/tcp/streams/__init__.py

@@ -24,6 +24,9 @@ Each stream is defined by the following information:
 current_token: The function that returns the current token for the stream
 update_function: The function that returns a list of updates between two tokens
 """
+
+from typing import Dict, Type
+
 from synapse.replication.tcp.streams._base import (
     AccountDataStream,
     BackfillStream,

@@ -35,6 +38,7 @@ from synapse.replication.tcp.streams._base import (
     PushersStream,
     PushRulesStream,
     ReceiptsStream,
+    Stream,
     TagAccountDataStream,
     ToDeviceStream,
     TypingStream,

@@ -63,10 +67,12 @@ STREAMS_MAP = {
         GroupServerStream,
         UserSignatureStream,
     )
-}
+}  # type: Dict[str, Type[Stream]]


 __all__ = [
     "STREAMS_MAP",
+    "Stream",
     "BackfillStream",
     "PresenceStream",
     "TypingStream",
synapse/replication/tcp/streams/_base.py

@@ -14,13 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import itertools
 import logging
 from collections import namedtuple
-from typing import Any, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, List, Optional, Tuple

 import attr

+from synapse.replication.http.streams import ReplicationGetStreamUpdates
 from synapse.types import JsonDict

 logger = logging.getLogger(__name__)

@@ -29,6 +29,15 @@ logger = logging.getLogger(__name__)
 MAX_EVENTS_BEHIND = 500000


+# Some type aliases to make things a bit easier.
+
+# A stream position token
+Token = int
+
+# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
+StreamRow = Tuple[Token, tuple]
+
+
 class Stream(object):
     """Base class for the streams.

@@ -56,6 +65,7 @@ class Stream(object):
         return cls.ROW_TYPE(*row)

     def __init__(self, hs):
+
         # The token from which we last asked for updates
         self.last_token = self.current_token()
@@ -65,61 +75,46 @@ class Stream(object):
         """
         self.last_token = self.current_token()

-    async def get_updates(self):
+    async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
         """Gets all updates since the last time this function was called (or
         since the stream was constructed if it hadn't been called before).

         Returns:
-            Deferred[Tuple[List[Tuple[int, Any]], int]:
-                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
-                list of ``(token, row)`` entries. ``row`` will be json-serialised and
-                sent over the replication steam.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            position in stream, and `limited` is whether there are more updates
+            to fetch.
         """
-        updates, current_token = await self.get_updates_since(self.last_token)
+        current_token = self.current_token()
+        updates, current_token, limited = await self.get_updates_since(
+            self.last_token, current_token
+        )
         self.last_token = current_token

-        return updates, current_token
+        return updates, current_token, limited

     async def get_updates_since(
-        self, from_token: int
-    ) -> Tuple[List[Tuple[int, JsonDict]], int]:
+        self, from_token: Token, upto_token: Token, limit: int = 100
+    ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
         """Like get_updates except allows specifying from when we should
         stream updates

         Returns:
-            Resolves to a pair `(updates, new_last_token)`, where `updates` is
-            a list of `(token, row)` entries and `new_last_token` is the new
-            position in stream.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            position in stream, and `limited` is whether there are more updates
+            to fetch.
         """

-        if from_token in ("NOW", "now"):
-            return [], self.current_token()
-
-        current_token = self.current_token()
-
         from_token = int(from_token)

-        if from_token == current_token:
-            return [], current_token
+        if from_token == upto_token:
+            return [], upto_token, False

-        rows = await self.update_function(
-            from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
-        )
-
-        # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
-        rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
-
-        updates = [(row[0], row[1:]) for row in rows]
-
-        # check we didn't get more rows than the limit.
-        # doing it like this allows the update_function to be a generator.
-        if len(updates) >= MAX_EVENTS_BEHIND:
-            raise Exception("stream %s has fallen behind" % (self.NAME))
-
-        # The update function didn't hit the limit, so we must have got all
-        # the updates to `current_token`, and can return that as our new
-        # stream position.
-        return updates, current_token
+        updates, upto_token, limited = await self.update_function(
+            from_token, upto_token, limit=limit,
+        )
+        return updates, upto_token, limited

     def current_token(self):
         """Gets the current token of the underlying streams. Should be provided
@@ -141,6 +136,48 @@ class Stream(object):
         raise NotImplementedError()


+def db_query_to_update_function(
+    query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
+) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+    """Wraps a db query function which returns a list of rows to make it
+    suitable for use as an `update_function` for the Stream class
+    """
+
+    async def update_function(from_token, upto_token, limit):
+        rows = await query_function(from_token, upto_token, limit)
+        updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) == limit:
+            upto_token = rows[-1][0]
+            limited = True
+
+        return updates, upto_token, limited
+
+    return update_function
+
+
+def make_http_update_function(
+    hs, stream_name: str
+) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+    """Makes a suitable function for use as an `update_function` that queries
+    the master process for updates.
+    """
+
+    client = ReplicationGetStreamUpdates.make_client(hs)
+
+    async def update_function(
+        from_token: int, upto_token: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        return await client(
+            stream_name=stream_name,
+            from_token=from_token,
+            upto_token=upto_token,
+            limit=limit,
+        )
+
+    return update_function
+
+
 class BackfillStream(Stream):
     """We fetched some old events and either we had never seen that event before
     or it went from being an outlier to not.
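A short usage sketch (assumed, but mirroring how the streams below use the wrapper): any async DB query that returns rows whose first element is the stream token can be adapted directly. The `widget` query here is hypothetical:

```python
# Sketch: adapting a row-returning DB query to the update_function contract.
async def get_all_widget_rows(from_token, upto_token, limit):
    # Hypothetical query returning [(token, col1, col2), ...] ordered by token.
    return []

update_function = db_query_to_update_function(get_all_widget_rows)
# update_function(from_token, upto_token, limit) now resolves to
# (updates, new_upto_token, limited), as Stream requires.
```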
@@ -164,7 +201,7 @@ class BackfillStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         self.current_token = store.get_current_backfill_token  # type: ignore
-        self.update_function = store.get_all_new_backfill_event_rows  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows)  # type: ignore

         super(BackfillStream, self).__init__(hs)

@@ -190,8 +227,15 @@ class PresenceStream(Stream):
         store = hs.get_datastore()
         presence_handler = hs.get_presence_handler()

+        self._is_worker = hs.config.worker_app is not None
+
         self.current_token = store.get_current_presence_token  # type: ignore
-        self.update_function = presence_handler.get_all_presence_updates  # type: ignore
+
+        if hs.config.worker_app is None:
+            self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates)  # type: ignore
+        else:
+            # Query master process
+            self.update_function = make_http_update_function(hs, self.NAME)  # type: ignore

         super(PresenceStream, self).__init__(hs)
@@ -208,7 +252,12 @@ class TypingStream(Stream):
         typing_handler = hs.get_typing_handler()

         self.current_token = typing_handler.get_current_token  # type: ignore
-        self.update_function = typing_handler.get_all_typing_updates  # type: ignore
+
+        if hs.config.worker_app is None:
+            self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates)  # type: ignore
+        else:
+            # Query master process
+            self.update_function = make_http_update_function(hs, self.NAME)  # type: ignore

         super(TypingStream, self).__init__(hs)

@@ -232,7 +281,7 @@ class ReceiptsStream(Stream):
         store = hs.get_datastore()

         self.current_token = store.get_max_receipt_stream_id  # type: ignore
-        self.update_function = store.get_all_updated_receipts  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_receipts)  # type: ignore

         super(ReceiptsStream, self).__init__(hs)
@@ -256,7 +305,13 @@ class PushRulesStream(Stream):

     async def update_function(self, from_token, to_token, limit):
         rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
-        return [(row[0], row[2]) for row in rows]
+
+        limited = False
+        if len(rows) == limit:
+            to_token = rows[-1][0]
+            limited = True
+
+        return [(row[0], (row[2],)) for row in rows], to_token, limited


 class PushersStream(Stream):

@@ -275,7 +330,7 @@ class PushersStream(Stream):
         store = hs.get_datastore()

         self.current_token = store.get_pushers_stream_token  # type: ignore
-        self.update_function = store.get_all_updated_pushers_rows  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows)  # type: ignore

         super(PushersStream, self).__init__(hs)
@@ -307,7 +362,7 @@ class CachesStream(Stream):
         store = hs.get_datastore()

         self.current_token = store.get_cache_stream_token  # type: ignore
-        self.update_function = store.get_all_updated_caches  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_caches)  # type: ignore

         super(CachesStream, self).__init__(hs)

@@ -333,7 +388,7 @@ class PublicRoomsStream(Stream):
         store = hs.get_datastore()

         self.current_token = store.get_current_public_room_stream_id  # type: ignore
-        self.update_function = store.get_all_new_public_rooms  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_public_rooms)  # type: ignore

         super(PublicRoomsStream, self).__init__(hs)

@@ -354,7 +409,7 @@ class DeviceListsStream(Stream):
         store = hs.get_datastore()

         self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = store.get_all_device_list_changes_for_remotes  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes)  # type: ignore

         super(DeviceListsStream, self).__init__(hs)

@@ -372,7 +427,7 @@ class ToDeviceStream(Stream):
         store = hs.get_datastore()

         self.current_token = store.get_to_device_stream_token  # type: ignore
-        self.update_function = store.get_all_new_device_messages  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_device_messages)  # type: ignore

         super(ToDeviceStream, self).__init__(hs)
@ -392,7 +447,7 @@ class TagAccountDataStream(Stream):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
|
|
||||||
self.current_token = store.get_max_account_data_stream_id # type: ignore
|
self.current_token = store.get_max_account_data_stream_id # type: ignore
|
||||||
self.update_function = store.get_all_updated_tags # type: ignore
|
self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
|
||||||
|
|
||||||
super(TagAccountDataStream, self).__init__(hs)
|
super(TagAccountDataStream, self).__init__(hs)
|
||||||
|
|
||||||
|
@ -412,10 +467,11 @@ class AccountDataStream(Stream):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
|
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
|
||||||
|
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
|
||||||
|
|
||||||
super(AccountDataStream, self).__init__(hs)
|
super(AccountDataStream, self).__init__(hs)
|
||||||
|
|
||||||
async def update_function(self, from_token, to_token, limit):
|
async def _update_function(self, from_token, to_token, limit):
|
||||||
global_results, room_results = await self.store.get_all_updated_account_data(
|
global_results, room_results = await self.store.get_all_updated_account_data(
|
||||||
from_token, from_token, to_token, limit
|
from_token, from_token, to_token, limit
|
||||||
)
|
)
|
||||||
|
@ -442,7 +498,7 @@ class GroupServerStream(Stream):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
|
|
||||||
self.current_token = store.get_group_stream_token # type: ignore
|
self.current_token = store.get_group_stream_token # type: ignore
|
||||||
self.update_function = store.get_all_groups_changes # type: ignore
|
self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
|
||||||
|
|
||||||
super(GroupServerStream, self).__init__(hs)
|
super(GroupServerStream, self).__init__(hs)
|
||||||
|
|
||||||
|
@@ -460,6 +516,6 @@ class UserSignatureStream(Stream):
         store = hs.get_datastore()

         self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = store.get_all_user_signature_changes_for_remotes  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes)  # type: ignore

         super(UserSignatureStream, self).__init__(hs)

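Each of these streams previously exposed a raw database query as its `update_function`; they now wrap it in the new `db_query_to_update_function` helper from `_base`, which adapts a `(from_token, to_token, limit)` query into the `(updates, upto_token, limited)` interface the replication machinery expects. The helper's body is outside this excerpt, so the following is only a minimal sketch of the shape implied by these call sites (the exact bookkeeping of the `limited` flag is an assumption):

def db_query_to_update_function(query_function):
    """Sketch of a wrapper turning a DB query that takes
    (from_token, to_token, limit) and returns rows whose first element is a
    stream token into a stream update_function. The real helper lives in
    synapse/replication/tcp/streams/_base.py and may differ in detail."""

    async def update_function(from_token, upto_token, limit):
        rows = await query_function(from_token, upto_token, limit)
        # Split each row into a (token, payload) pair.
        updates = [(row[0], row[1:]) for row in rows]
        limited = False
        if len(updates) == limit:
            # We may have hit the limit mid-stream; report the token we
            # actually reached so the client knows to ask again from there.
            upto_token = updates[-1][0]
            limited = True
        return updates, upto_token, limited

    return update_function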
@@ -19,7 +19,7 @@ from typing import Tuple, Type

 import attr

-from ._base import Stream
+from ._base import Stream, db_query_to_update_function


 """Handling of the 'events' replication stream

@@ -117,10 +117,11 @@ class EventsStream(Stream):
     def __init__(self, hs):
         self._store = hs.get_datastore()
         self.current_token = self._store.get_current_events_token  # type: ignore
+        self.update_function = db_query_to_update_function(self._update_function)  # type: ignore

         super(EventsStream, self).__init__(hs)

-    async def update_function(self, from_token, current_token, limit=None):
+    async def _update_function(self, from_token, current_token, limit=None):
         event_rows = await self._store.get_all_new_forward_event_rows(
             from_token, current_token, limit
         )

@@ -15,7 +15,9 @@
 # limitations under the License.
 from collections import namedtuple

-from ._base import Stream
+from twisted.internet import defer
+
+from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function


 class FederationStream(Stream):

@@ -33,11 +35,18 @@ class FederationStream(Stream):

     NAME = "federation"
     ROW_TYPE = FederationStreamRow
+    _QUERY_MASTER = True

     def __init__(self, hs):
-        federation_sender = hs.get_federation_sender()
-
-        self.current_token = federation_sender.get_current_token  # type: ignore
-        self.update_function = federation_sender.get_replication_rows  # type: ignore
+        # Not all synapse instances will have a federation sender instance,
+        # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
+        # so we stub the stream out when that is the case.
+        if hs.config.worker_app is None or hs.should_send_federation():
+            federation_sender = hs.get_federation_sender()
+            self.current_token = federation_sender.get_current_token  # type: ignore
+            self.update_function = db_query_to_update_function(federation_sender.get_replication_rows)  # type: ignore
+        else:
+            self.current_token = lambda: 0  # type: ignore
+            self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool))  # type: ignore

         super(FederationStream, self).__init__(hs)

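Worth noting: on a worker that is not the federation sender, the stream is stubbed out rather than removed, so the generic replication code can still call it. A toy illustration of what the stubbed endpoints yield (the bare `bool` in the tuple is reproduced verbatim from the patch; `False` appears to be what was intended):

from twisted.internet import defer

# Stand-ins mirroring the else-branch above: the token never advances and no
# rows are ever produced, so there is nothing for clients to catch up on.
current_token = lambda: 0
update_function = lambda from_token, upto_token, limit: defer.succeed(
    ([], upto_token, bool)
)

update_function(0, 0, 100).addCallback(print)  # prints ([], 0, <class 'bool'>)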
@@ -85,6 +85,7 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 from synapse.notifier import Notifier
 from synapse.push.action_generator import ActionGenerator
 from synapse.push.pusherpool import PusherPool
+from synapse.replication.tcp.resource import ReplicationStreamer
 from synapse.rest.media.v1.media_repository import (
     MediaRepository,
     MediaRepositoryResource,

@@ -199,6 +200,7 @@ class HomeServer(object):
         "saml_handler",
         "event_client_serializer",
         "storage",
+        "replication_streamer",
     ]

     REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]

@@ -536,6 +538,9 @@ class HomeServer(object):
     def build_storage(self) -> Storage:
         return Storage(self, self.datastores)

+    def build_replication_streamer(self) -> ReplicationStreamer:
+        return ReplicationStreamer(self)
+
     def remove_pusher(self, app_id, push_key, user_id):
         return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

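Adding `"replication_streamer"` to `DEPENDENCIES` together with the `build_replication_streamer` method is enough to give every process a lazily built, cached `hs.get_replication_streamer()` accessor. Roughly how that machinery behaves, as a simplified sketch rather than Synapse's actual metaprogramming:

class MiniHomeServer:
    """Toy model of HomeServer's dependency pattern: each name listed in
    DEPENDENCIES is resolved by calling build_<name>() once and caching it."""

    DEPENDENCIES = ["replication_streamer"]

    def __init__(self):
        self._built = {}

    def get(self, name):
        # The real HomeServer generates a get_<name>() method per dependency;
        # this single lookup method stands in for that.
        if name not in self._built:
            self._built[name] = getattr(self, "build_" + name)()
        return self._built[name]

    def build_replication_streamer(self):
        return object()  # stands in for ReplicationStreamer(self)


hs = MiniHomeServer()
# Built once on first use, then reused.
assert hs.get("replication_streamer") is hs.get("replication_streamer")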
@@ -32,7 +32,29 @@ logger = logging.getLogger(__name__)
 CURRENT_STATE_CACHE_NAME = "cs_cache_fake"


-class CacheInvalidationStore(SQLBaseStore):
+class CacheInvalidationWorkerStore(SQLBaseStore):
+    def get_all_updated_caches(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_updated_caches_txn(txn):
+            # We purposefully don't bound by the current token, as we want to
+            # send across cache invalidations as quickly as possible. Cache
+            # invalidations are idempotent, so duplicates are fine.
+            sql = (
+                "SELECT stream_id, cache_func, keys, invalidation_ts"
+                " FROM cache_invalidation_stream"
+                " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, limit))
+            return txn.fetchall()
+
+        return self.db.runInteraction(
+            "get_all_updated_caches", get_all_updated_caches_txn
+        )
+
+
+class CacheInvalidationStore(CacheInvalidationWorkerStore):
     async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.

@@ -145,26 +167,6 @@ class CacheInvalidationStore(SQLBaseStore):
             },
         )

-    def get_all_updated_caches(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_updated_caches_txn(txn):
-            # We purposefully don't bound by the current token, as we want to
-            # send across cache invalidations as quickly as possible. Cache
-            # invalidations are idempotent, so duplicates are fine.
-            sql = (
-                "SELECT stream_id, cache_func, keys, invalidation_ts"
-                " FROM cache_invalidation_stream"
-                " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, limit))
-            return txn.fetchall()
-
-        return self.db.runInteraction(
-            "get_all_updated_caches", get_all_updated_caches_txn
-        )
-
     def get_cache_stream_token(self):
         if self._cache_id_gen:
             return self._cache_id_gen.get_current_token()

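The same refactoring pattern recurs through the rest of the commit: the read-only query backing a replication stream moves from the master-only store class into a `...WorkerStore` base class, so worker processes (whose slaved stores inherit only the worker classes) can run the catchup query against their own database instead of asking the master. Schematically, as a sketch of the pattern rather than real Synapse classes:

class ExampleWorkerStore:
    """Read side: safe to run on any process, so replication catchup
    can happen wherever the client is connected."""

    def get_all_updated_examples(self, last_id, current_id, limit):
        ...  # SELECT rows with last_id < stream_id <= current_id


class ExampleStore(ExampleWorkerStore):
    """Write side: only the master appends to the stream, so writers
    stay in the subclass that only the master process loads."""

    def update_example_and_stream(self, *args):
        ...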
@@ -207,6 +207,50 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
         )

+    def get_all_new_device_messages(self, last_pos, current_pos, limit):
+        """
+        Args:
+            last_pos(int):
+            current_pos(int):
+            limit(int):
+        Returns:
+            A deferred list of rows from the device inbox
+        """
+        if last_pos == current_pos:
+            return defer.succeed([])
+
+        def get_all_new_device_messages_txn(txn):
+            # We limit like this as we might have multiple rows per stream_id, and
+            # we want to make sure we always get all entries for any stream_id
+            # we return.
+            upper_pos = min(current_pos, last_pos + limit)
+            sql = (
+                "SELECT max(stream_id), user_id"
+                " FROM device_inbox"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " GROUP BY user_id"
+            )
+            txn.execute(sql, (last_pos, upper_pos))
+            rows = txn.fetchall()
+
+            sql = (
+                "SELECT max(stream_id), destination"
+                " FROM device_federation_outbox"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " GROUP BY destination"
+            )
+            txn.execute(sql, (last_pos, upper_pos))
+            rows.extend(txn)
+
+            # Order by ascending stream ordering
+            rows.sort()
+
+            return rows
+
+        return self.db.runInteraction(
+            "get_all_new_device_messages", get_all_new_device_messages_txn
+        )
+

 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"

@@ -411,47 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 rows.append((user_id, device_id, stream_id, message_json))

             txn.executemany(sql, rows)
-
-    def get_all_new_device_messages(self, last_pos, current_pos, limit):
-        """
-        Args:
-            last_pos(int):
-            current_pos(int):
-            limit(int):
-        Returns:
-            A deferred list of rows from the device inbox
-        """
-        if last_pos == current_pos:
-            return defer.succeed([])
-
-        def get_all_new_device_messages_txn(txn):
-            # We limit like this as we might have multiple rows per stream_id, and
-            # we want to make sure we always get all entries for any stream_id
-            # we return.
-            upper_pos = min(current_pos, last_pos + limit)
-            sql = (
-                "SELECT max(stream_id), user_id"
-                " FROM device_inbox"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " GROUP BY user_id"
-            )
-            txn.execute(sql, (last_pos, upper_pos))
-            rows = txn.fetchall()
-
-            sql = (
-                "SELECT max(stream_id), destination"
-                " FROM device_federation_outbox"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " GROUP BY destination"
-            )
-            txn.execute(sql, (last_pos, upper_pos))
-            rows.extend(txn)
-
-            # Order by ascending stream ordering
-            rows.sort()
-
-            return rows
-
-        return self.db.runInteraction(
-            "get_all_new_device_messages", get_all_new_device_messages_txn
-        )

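The `upper_pos = min(current_pos, last_pos + limit)` trick deserves a note: `device_inbox` can hold several rows per `stream_id`, so a plain SQL `LIMIT` could truncate mid-`stream_id` and the replication client would silently miss rows. Capping the token range instead guarantees whole groups. A tiny self-contained illustration with fake rows and hypothetical values:

# Fake (stream_id, recipient) rows; stream_id 2 has two entries.
rows = [(1, "@a:x"), (2, "@b:x"), (2, "@c:x"), (3, "@d:x")]
last_pos, current_pos, limit = 0, 3, 2

# A naive LIMIT would split stream_id 2 across batches: if the client then
# advances its token to 2, the "@c:x" entry is never delivered.
print(rows[:limit])                 # [(1, '@a:x'), (2, '@b:x')]

# Bounding the token range keeps every row for each returned stream_id:
upper_pos = min(current_pos, last_pos + limit)   # = 2
batch = [r for r in rows if last_pos < r[0] <= upper_pos]
print(batch)                        # [(1, '@a:x'), (2, '@b:x'), (2, '@c:x')]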
@@ -1267,104 +1267,6 @@ class EventsStore(
         ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
         return ret

-    def get_current_backfill_token(self):
-        """The current minimum token that backfilled events have reached"""
-        return -self._backfill_id_gen.get_current_token()
-
-    def get_current_events_token(self):
-        """The current maximum token that events have reached"""
-        return self._stream_id_gen.get_current_token()
-
-    def get_all_new_forward_event_rows(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_new_forward_event_rows(txn):
-            sql = (
-                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? < stream_ordering AND stream_ordering <= ?"
-                " ORDER BY stream_ordering ASC"
-                " LIMIT ?"
-            )
-            txn.execute(sql, (last_id, current_id, limit))
-            new_event_updates = txn.fetchall()
-
-            if len(new_event_updates) == limit:
-                upper_bound = new_event_updates[-1][0]
-            else:
-                upper_bound = current_id
-
-            sql = (
-                "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " INNER JOIN ex_outlier_stream USING (event_id)"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? < event_stream_ordering"
-                " AND event_stream_ordering <= ?"
-                " ORDER BY event_stream_ordering DESC"
-            )
-            txn.execute(sql, (last_id, upper_bound))
-            new_event_updates.extend(txn)
-
-            return new_event_updates
-
-        return self.db.runInteraction(
-            "get_all_new_forward_event_rows", get_all_new_forward_event_rows
-        )
-
-    def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_new_backfill_event_rows(txn):
-            sql = (
-                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? > stream_ordering AND stream_ordering >= ?"
-                " ORDER BY stream_ordering ASC"
-                " LIMIT ?"
-            )
-            txn.execute(sql, (-last_id, -current_id, limit))
-            new_event_updates = txn.fetchall()
-
-            if len(new_event_updates) == limit:
-                upper_bound = new_event_updates[-1][0]
-            else:
-                upper_bound = current_id
-
-            sql = (
-                "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " INNER JOIN ex_outlier_stream USING (event_id)"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? > event_stream_ordering"
-                " AND event_stream_ordering >= ?"
-                " ORDER BY event_stream_ordering DESC"
-            )
-            txn.execute(sql, (-last_id, -upper_bound))
-            new_event_updates.extend(txn.fetchall())
-
-            return new_event_updates
-
-        return self.db.runInteraction(
-            "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
-        )
-
     @cached(num_args=5, max_entries=10)
     def get_all_new_events(
         self,

@@ -1850,22 +1752,6 @@ class EventsStore(

         return (int(res["topological_ordering"]), int(res["stream_ordering"]))

-    def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
-        def get_all_updated_current_state_deltas_txn(txn):
-            sql = """
-                SELECT stream_id, room_id, type, state_key, event_id
-                FROM current_state_delta_stream
-                WHERE ? < stream_id AND stream_id <= ?
-                ORDER BY stream_id ASC LIMIT ?
-            """
-            txn.execute(sql, (from_token, to_token, limit))
-            return txn.fetchall()
-
-        return self.db.runInteraction(
-            "get_all_updated_current_state_deltas",
-            get_all_updated_current_state_deltas_txn,
-        )
-
     def insert_labels_for_event_txn(
         self, txn, event_id, labels, room_id, topological_ordering
     ):

@@ -963,3 +963,117 @@ class EventsWorkerStore(SQLBaseStore):
         complexity_v1 = round(state_events / 500, 2)

         return {"v1": complexity_v1}
+
+    def get_current_backfill_token(self):
+        """The current minimum token that backfilled events have reached"""
+        return -self._backfill_id_gen.get_current_token()
+
+    def get_current_events_token(self):
+        """The current maximum token that events have reached"""
+        return self._stream_id_gen.get_current_token()
+
+    def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_new_forward_event_rows(txn):
+            sql = (
+                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                " ORDER BY stream_ordering ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            new_event_updates = txn.fetchall()
+
+            if len(new_event_updates) == limit:
+                upper_bound = new_event_updates[-1][0]
+            else:
+                upper_bound = current_id
+
+            sql = (
+                "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? < event_stream_ordering"
+                " AND event_stream_ordering <= ?"
+                " ORDER BY event_stream_ordering DESC"
+            )
+            txn.execute(sql, (last_id, upper_bound))
+            new_event_updates.extend(txn)
+
+            return new_event_updates
+
+        return self.db.runInteraction(
+            "get_all_new_forward_event_rows", get_all_new_forward_event_rows
+        )
+
+    def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_new_backfill_event_rows(txn):
+            sql = (
+                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                " ORDER BY stream_ordering ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (-last_id, -current_id, limit))
+            new_event_updates = txn.fetchall()
+
+            if len(new_event_updates) == limit:
+                upper_bound = new_event_updates[-1][0]
+            else:
+                upper_bound = current_id
+
+            sql = (
+                "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? > event_stream_ordering"
+                " AND event_stream_ordering >= ?"
+                " ORDER BY event_stream_ordering DESC"
+            )
+            txn.execute(sql, (-last_id, -upper_bound))
+            new_event_updates.extend(txn.fetchall())
+
+            return new_event_updates
+
+        return self.db.runInteraction(
+            "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
+        )
+
+    def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
+        def get_all_updated_current_state_deltas_txn(txn):
+            sql = """
+                SELECT stream_id, room_id, type, state_key, event_id
+                FROM current_state_delta_stream
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC LIMIT ?
+            """
+            txn.execute(sql, (from_token, to_token, limit))
+            return txn.fetchall()
+
+        return self.db.runInteraction(
+            "get_all_updated_current_state_deltas",
+            get_all_updated_current_state_deltas_txn,
+        )

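One subtlety in the moved code: backfilled events are stored with negative `stream_ordering` values (they grow downwards as history is fetched), so `get_current_backfill_token` negates the generator's value to hand replication a token that still increases, and `get_all_new_backfill_event_rows` negates both bounds back when querying. A small sketch of the sign convention, with illustrative values only:

# Backfilled events get stream orderings -1, -2, -3, ... as we walk history.
stored_orderings = [-1, -2, -3]

# The exposed token is negated so it is monotonically increasing: here, 3.
current_token = -min(stored_orderings)

# The query flips the bounds back, mirroring
# "WHERE ? > stream_ordering AND stream_ordering >= ?" with (-last_id, -current_id):
last_id = 1
matching = [o for o in stored_orderings if -last_id > o >= -current_token]
print(matching)  # [-2, -3]: everything backfilled since token 1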
@@ -732,6 +732,26 @@ class RoomWorkerStore(SQLBaseStore):

         return total_media_quarantined

+    def get_all_new_public_rooms(self, prev_id, current_id, limit):
+        def get_all_new_public_rooms(txn):
+            sql = """
+                SELECT stream_id, room_id, visibility, appservice_id, network_id
+                FROM public_room_list_stream
+                WHERE stream_id > ? AND stream_id <= ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+
+            txn.execute(sql, (prev_id, current_id, limit))
+            return txn.fetchall()
+
+        if prev_id == current_id:
+            return defer.succeed([])
+
+        return self.db.runInteraction(
+            "get_all_new_public_rooms", get_all_new_public_rooms
+        )
+

 class RoomBackgroundUpdateStore(SQLBaseStore):
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"

@@ -1249,26 +1269,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()

-    def get_all_new_public_rooms(self, prev_id, current_id, limit):
-        def get_all_new_public_rooms(txn):
-            sql = """
-                SELECT stream_id, room_id, visibility, appservice_id, network_id
-                FROM public_room_list_stream
-                WHERE stream_id > ? AND stream_id <= ?
-                ORDER BY stream_id ASC
-                LIMIT ?
-            """
-
-            txn.execute(sql, (prev_id, current_id, limit))
-            return txn.fetchall()
-
-        if prev_id == current_id:
-            return defer.succeed([])
-
-        return self.db.runInteraction(
-            "get_all_new_public_rooms", get_all_new_public_rooms
-        )
-
     @defer.inlineCallbacks
     def block_room(self, room_id, user_id):
         """Marks the room as blocked. Can be called multiple times.

@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from mock import Mock

 from synapse.replication.tcp.commands import ReplicateCommand

@@ -29,19 +30,37 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(self.hs)
         self.streamer = server_factory.streamer
-        server = server_factory.buildProtocol(None)
+        self.server = server_factory.buildProtocol(None)

-        # build a replication client, with a dummy handler
-        handler_factory = Mock()
-        self.test_handler = TestReplicationClientHandler()
-        self.test_handler.factory = handler_factory
+        self.test_handler = Mock(wraps=TestReplicationClientHandler())
         self.client = ClientReplicationStreamProtocol(
-            "client", "test", clock, self.test_handler
+            hs, "client", "test", clock, self.test_handler,
         )

-        # wire them together
-        self.client.makeConnection(FakeTransport(server, reactor))
-        server.makeConnection(FakeTransport(self.client, reactor))
+        self._client_transport = None
+        self._server_transport = None
+
+    def reconnect(self):
+        if self._client_transport:
+            self.client.close()
+
+        if self._server_transport:
+            self.server.close()
+
+        self._client_transport = FakeTransport(self.server, self.reactor)
+        self.client.makeConnection(self._client_transport)
+
+        self._server_transport = FakeTransport(self.client, self.reactor)
+        self.server.makeConnection(self._server_transport)
+
+    def disconnect(self):
+        if self._client_transport:
+            self._client_transport = None
+            self.client.close()
+
+        if self._server_transport:
+            self._server_transport = None
+            self.server.close()

     def replicate(self):
         """Tell the master side of replication that something has happened, and then

|
||||||
self.streamer.on_notifier_poke()
|
self.streamer.on_notifier_poke()
|
||||||
self.pump(0.1)
|
self.pump(0.1)
|
||||||
|
|
||||||
def replicate_stream(self, stream, token="NOW"):
|
def replicate_stream(self):
|
||||||
"""Make the client end a REPLICATE command to set up a subscription to a stream"""
|
"""Make the client end a REPLICATE command to set up a subscription to a stream"""
|
||||||
self.client.send_command(ReplicateCommand(stream, token))
|
self.client.send_command(ReplicateCommand())
|
||||||
|
|
||||||
|
|
||||||
class TestReplicationClientHandler(object):
|
class TestReplicationClientHandler(object):
|
||||||
"""Drop-in for ReplicationClientHandler which just collects RDATA rows"""
|
"""Drop-in for ReplicationClientHandler which just collects RDATA rows"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.received_rdata_rows = []
|
self.streams = set()
|
||||||
|
self._received_rdata_rows = []
|
||||||
|
|
||||||
def get_streams_to_replicate(self):
|
def get_streams_to_replicate(self):
|
||||||
return {}
|
positions = {s: 0 for s in self.streams}
|
||||||
|
for stream, token, _ in self._received_rdata_rows:
|
||||||
|
if stream in self.streams:
|
||||||
|
positions[stream] = max(token, positions.get(stream, 0))
|
||||||
|
return positions
|
||||||
|
|
||||||
def get_currently_syncing_users(self):
|
def get_currently_syncing_users(self):
|
||||||
return []
|
return []
|
||||||
|
@@ -73,6 +97,9 @@ class TestReplicationClientHandler(object):
     def finished_connecting(self):
         pass

+    async def on_position(self, stream_name, token):
+        """Called when we get new position data."""
+
     async def on_rdata(self, stream_name, token, rows):
         for r in rows:
-            self.received_rdata_rows.append((stream_name, token, r))
+            self._received_rdata_rows.append((stream_name, token, r))

@@ -17,30 +17,64 @@ from synapse.replication.tcp.streams._base import ReceiptsStream
 from tests.replication.tcp.streams._base import BaseStreamTestCase

 USER_ID = "@feeling:blue"
-ROOM_ID = "!room:blue"
-EVENT_ID = "$event:blue"


 class ReceiptsStreamTestCase(BaseStreamTestCase):
     def test_receipt(self):
+        self.reconnect()
+
         # make the client subscribe to the receipts stream
-        self.replicate_stream("receipts", "NOW")
+        self.replicate_stream()
+        self.test_handler.streams.add("receipts")

         # tell the master to send a new receipt
         self.get_success(
             self.hs.get_datastore().insert_receipt(
-                ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1}
+                "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
             )
         )
         self.replicate()

         # there should be one RDATA command
-        rdata_rows = self.test_handler.received_rdata_rows
+        self.test_handler.on_rdata.assert_called_once()
+        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.assertEqual(stream_name, "receipts")
         self.assertEqual(1, len(rdata_rows))
-        self.assertEqual(rdata_rows[0][0], "receipts")
-        row = rdata_rows[0][2]  # type: ReceiptsStream.ReceiptsStreamRow
-        self.assertEqual(ROOM_ID, row.room_id)
+        row = rdata_rows[0]  # type: ReceiptsStream.ReceiptsStreamRow
+        self.assertEqual("!room:blue", row.room_id)
         self.assertEqual("m.read", row.receipt_type)
         self.assertEqual(USER_ID, row.user_id)
-        self.assertEqual(EVENT_ID, row.event_id)
+        self.assertEqual("$event:blue", row.event_id)
         self.assertEqual({"a": 1}, row.data)

+        # Now let's disconnect and insert some data.
+        self.disconnect()
+
+        self.test_handler.on_rdata.reset_mock()
+
+        self.get_success(
+            self.hs.get_datastore().insert_receipt(
+                "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
+            )
+        )
+        self.replicate()
+
+        # Nothing should have happened as we are disconnected
+        self.test_handler.on_rdata.assert_not_called()
+
+        self.reconnect()
+        self.pump(0.1)
+
+        # We should now have caught up and get the missing data
+        self.test_handler.on_rdata.assert_called_once()
+        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.assertEqual(stream_name, "receipts")
+        self.assertEqual(token, 3)
+        self.assertEqual(1, len(rdata_rows))
+
+        row = rdata_rows[0]  # type: ReceiptsStream.ReceiptsStreamRow
+        self.assertEqual("!room2:blue", row.room_id)
+        self.assertEqual("m.read", row.receipt_type)
+        self.assertEqual(USER_ID, row.user_id)
+        self.assertEqual("$event2:foo", row.event_id)
+        self.assertEqual({"a": 2}, row.data)