Pass instance name through to rdata

This commit is contained in:
Erik Johnston 2020-03-25 14:05:53 +00:00
parent 092b62ee7b
commit 0473f87a17
3 changed files with 23 additions and 12 deletions

View file

@ -608,9 +608,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
else: else:
self.send_handler = None self.send_handler = None
async def on_rdata(self, stream_name, token, rows): async def on_rdata(self, stream_name, instance_name, token, rows):
await super().on_rdata(stream_name, token, rows) await super().on_rdata(stream_name, instance_name, token, rows)
run_in_background(self.process_and_notify, stream_name, token, rows) run_in_background(
self.process_and_notify, stream_name, instance_name, token, rows
)
def get_streams_to_replicate(self): def get_streams_to_replicate(self):
args = super().get_streams_to_replicate() args = super().get_streams_to_replicate()
@ -619,7 +621,7 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
args.update(self.send_handler.stream_positions()) args.update(self.send_handler.stream_positions())
return args return args
async def process_and_notify(self, stream_name, token, rows): async def process_and_notify(self, stream_name, instance_name, token, rows):
try: try:
if self.send_handler: if self.send_handler:
self.send_handler.process_replication_rows(stream_name, token, rows) self.send_handler.process_replication_rows(stream_name, token, rows)

View file

@ -65,7 +65,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
limit = parse_integer(request, "limit", required=True) limit = parse_integer(request, "limit", required=True)
updates, upto_token, limited = await stream.get_updates_since( updates, upto_token, limited = await stream.get_updates_since(
from_token, upto_token, limit self.instance_name, from_token, upto_token, limit
) )
return ( return (

View file

@ -207,9 +207,11 @@ class ReplicationClientHandler:
# Check if this is the last of a batch of updates # Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, []) rows = self.pending_batches.pop(stream_name, [])
rows.append(row) rows.append(row)
await self.on_rdata(stream_name, cmd.token, rows) await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
async def on_rdata(self, stream_name: str, token: int, rows: list): async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
"""Called to handle a batch of replication data with a given stream token. """Called to handle a batch of replication data with a given stream token.
Args: Args:
@ -218,8 +220,10 @@ class ReplicationClientHandler:
rows: a list of Stream.ROW_TYPE objects as returned by rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row. Stream.parse_row.
""" """
logger.info("Received rdata %s -> %s", stream_name, token) logger.info("Received rdata %s %s -> %s", stream_name, instance_name, token)
await self.replication_data_handler.on_rdata(stream_name, token, rows) await self.replication_data_handler.on_rdata(
stream_name, instance_name, token, rows
)
async def on_POSITION(self, cmd: PositionCommand): async def on_POSITION(self, cmd: PositionCommand):
stream = self.streams.get(cmd.stream_name) stream = self.streams.get(cmd.stream_name)
@ -243,11 +247,12 @@ class ReplicationClientHandler:
limited = cmd.token != current_token limited = cmd.token != current_token
while limited: while limited:
updates, current_token, limited = await stream.get_updates_since( updates, current_token, limited = await stream.get_updates_since(
current_token, cmd.token cmd.instance_name, current_token, cmd.token
) )
if updates: if updates:
await self.on_rdata( await self.on_rdata(
cmd.stream_name, cmd.stream_name,
cmd.instance_name,
current_token, current_token,
[stream.parse_row(update[1]) for update in updates], [stream.parse_row(update[1]) for update in updates],
) )
@ -258,7 +263,9 @@ class ReplicationClientHandler:
# Handle any RDATA that came in while we were catching up. # Handle any RDATA that came in while we were catching up.
rows = self.pending_batches.pop(cmd.stream_name, []) rows = self.pending_batches.pop(cmd.stream_name, [])
if rows: if rows:
await self.on_rdata(cmd.stream_name, rows[-1].token, rows) await self.on_rdata(
cmd.stream_name, cmd.instance_name, rows[-1].token, rows
)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
"""Called when get a new REMOTE_SERVER_UP command.""" """Called when get a new REMOTE_SERVER_UP command."""
@ -342,7 +349,9 @@ class ReplicationDataHandler:
self.slaved_store = hs.config.worker_app is not None self.slaved_store = hs.config.worker_app is not None
self.slaved_typing = not hs.config.server.handle_typing self.slaved_typing = not hs.config.server.handle_typing
async def on_rdata(self, stream_name: str, token: int, rows: list): async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
"""Called to handle a batch of replication data with a given stream token. """Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to By default this just pokes the slave store. Can be overridden in subclasses to