Thread through instance name to replication client

This commit is contained in:
Erik Johnston 2020-03-25 11:24:48 +00:00
parent 6da24f2d5f
commit 9f15bffd72
2 changed files with 12 additions and 8 deletions

View file

@ -135,7 +135,10 @@ class ReplicationEndpoint(object):
@trace(opname="outgoing_replication_request")
@defer.inlineCallbacks
def send_request(**kwargs):
def send_request(instance_name="master", **kwargs):
if instance_name != "master":
raise Exception("Unknown instance")
data = yield cls._serialize_payload(**kwargs)
url_args = [

View file

@ -87,14 +87,14 @@ class Stream(object):
"""
current_token = self.current_token()
updates, current_token, limited = await self.get_updates_since(
self.last_token, current_token
"master", self.last_token, current_token
)
self.last_token = current_token
return updates, current_token, limited
async def get_updates_since(
self, from_token: Token, upto_token: Token, limit: int = 100
self, instance_name: str, from_token: Token, upto_token: Token, limit: int = 100
) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Like get_updates except allows specifying from when we should
stream updates
@ -112,7 +112,7 @@ class Stream(object):
return [], upto_token, False
updates, upto_token, limited = await self.update_function(
from_token, upto_token, limit=limit,
instance_name, from_token, upto_token, limit=limit,
)
return updates, upto_token, limited
@ -137,13 +137,13 @@ class Stream(object):
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
query_function: Callable[[str, Token, Token, int], Awaitable[List[tuple]]]
) -> Callable[[str, Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""
async def update_function(from_token, upto_token, limit):
async def update_function(instance_name, from_token, upto_token, limit):
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
@ -166,9 +166,10 @@ def make_http_update_function(
client = ReplicationGetStreamUpdates.make_client(hs)
async def update_function(
from_token: int, upto_token: int, limit: int
instance_name: str, from_token: int, upto_token: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
return await client(
instance_name=instance_name,
stream_name=stream_name,
from_token=from_token,
upto_token=upto_token,