Track device list updates per room. (#12321)

This is a first step in dealing with #7721.

The idea is basically that rather than calculating the full set of users a device list update needs to be sent to up front, we instead simply record the rooms the user was in at the time of the change. This will allow a few things:

1. we can defer calculating the set of remote servers that need to be poked about the change; and
2. during `/sync` and `/keys/changes` we can avoid also avoid calculating users who share rooms with other users, and instead just look at the rooms that have changed.

However, care needs to be taken to correctly handle server downgrades. As such this PR writes to both `device_lists_changes_in_room` and the `device_lists_outbound_pokes` table synchronously. In a future release we can then bump the database schema compat version to `69` and then we can assume that the new `device_lists_changes_in_room` exists and is handled.

There is a temporary option to disable writing to `device_lists_outbound_pokes` synchronously, allowing us to test the new code path does work (and by implication upgrading to a future release and downgrading to this one will work correctly).

Note: Ideally we'd do the calculation of room to servers on a worker (e.g. the background worker), but currently only master can write to the `device_list_outbound_pokes` table.
This commit is contained in:
Erik Johnston 2022-04-04 15:25:20 +01:00 committed by GitHub
parent 80839a44f1
commit 5c9e39e619
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 390 additions and 47 deletions

1
changelog.d/12321.misc Normal file
View file

@ -0,0 +1 @@
Add ground work for speeding up device list updates for users in large numbers of rooms.

View file

@ -97,6 +97,7 @@ BOOLEAN_COLUMNS = {
"users": ["shadow_banned"],
"e2e_fallback_keys_json": ["used"],
"access_tokens": ["used"],
"device_lists_changes_in_room": ["converted_to_destinations"],
}

View file

@ -680,6 +680,14 @@ class ServerConfig(Config):
config.get("use_account_validity_in_account_status") or False
)
# This is a temporary option that enables fully using the new
# `device_lists_changes_in_room` without the backwards compat code. This
# is primarily for testing. If enabled the server should *not* be
# downgraded, as it may lead to missing device list updates.
self.use_new_device_lists_changes_in_room = (
config.get("use_new_device_lists_changes_in_room") or False
)
self.rooms_to_exclude_from_sync: List[str] = (
config.get("exclude_rooms_from_sync") or []
)

View file

@ -37,7 +37,10 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.types import (
JsonDict,
StreamToken,
@ -278,6 +281,22 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
# Whether `_handle_new_device_update_async` is currently processing.
self._handle_new_device_update_is_processing = False
# If a new device update may have happened while the loop was
# processing.
self._handle_new_device_update_new_data = False
# On start up check if there are any updates pending.
hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
# Used to decide if we calculate outbound pokes up front or not. By
# default we do to allow safely downgrading Synapse.
self.use_new_device_lists_changes_in_room = (
hs.config.server.use_new_device_lists_changes_in_room
)
def _check_device_name_length(self, name: Optional[str]) -> None:
"""
Checks whether a device name is longer than the maximum allowed length.
@ -469,19 +488,26 @@ class DeviceHandler(DeviceWorkerHandler):
# No changes to notify about, so this is a no-op.
return
users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id
)
room_ids = await self.store.get_rooms_for_user(user_id)
hosts: Set[str] = set()
if self.hs.is_mine_id(user_id):
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
hosts.discard(self.server_name)
hosts: Optional[Set[str]] = None
if not self.use_new_device_lists_changes_in_room:
hosts = set()
set_tag("target_hosts", hosts)
if self.hs.is_mine_id(user_id):
for room_id in room_ids:
joined_users = await self.store.get_users_in_room(room_id)
hosts.update(get_domain_from_id(u) for u in joined_users)
set_tag("target_hosts", hosts)
hosts.discard(self.server_name)
position = await self.store.add_device_change_to_streams(
user_id, device_ids, list(hosts)
user_id,
device_ids,
hosts=hosts,
room_ids=room_ids,
)
if not position:
@ -495,9 +521,12 @@ class DeviceHandler(DeviceWorkerHandler):
# specify the user ID too since the user should always get their own device list
# updates, even if they aren't in any rooms.
users_to_notify = users_who_share_room.union({user_id})
self.notifier.on_new_event(
"device_list_key", position, users={user_id}, rooms=room_ids
)
self.notifier.on_new_event("device_list_key", position, users=users_to_notify)
# We may need to do some processing asynchronously.
self._handle_new_device_update_async()
if hosts:
logger.info(
@ -614,6 +643,85 @@ class DeviceHandler(DeviceWorkerHandler):
return {"success": True}
@wrap_as_background_process("_handle_new_device_update_async")
async def _handle_new_device_update_async(self) -> None:
"""Called when we have a new local device list update that we need to
send out over federation.
This happens in the background so as not to block the original request
that generated the device update.
"""
if self._handle_new_device_update_is_processing:
self._handle_new_device_update_new_data = True
return
self._handle_new_device_update_is_processing = True
# The stream ID we processed previous iteration (if any), and the set of
# hosts we've already poked about for this update. This is so that we
# don't poke the same remote server about the same update repeatedly.
current_stream_id = None
hosts_already_sent_to: Set[str] = set()
try:
while True:
self._handle_new_device_update_new_data = False
rows = await self.store.get_uncoverted_outbound_room_pokes()
if not rows:
# If the DB returned nothing then there is nothing left to
# do, *unless* a new device list update happened during the
# DB query.
if self._handle_new_device_update_new_data:
continue
else:
return
for user_id, device_id, room_id, stream_id, opentracing_context in rows:
joined_user_ids = await self.store.get_users_in_room(room_id)
hosts = {get_domain_from_id(u) for u in joined_user_ids}
hosts.discard(self.server_name)
# Check if we've already sent this update to some hosts
if current_stream_id == stream_id:
hosts -= hosts_already_sent_to
await self.store.add_device_list_outbound_pokes(
user_id=user_id,
device_id=device_id,
room_id=room_id,
stream_id=stream_id,
hosts=hosts,
context=opentracing_context,
)
# Notify replication that we've updated the device list stream.
self.notifier.notify_replication()
if hosts:
logger.info(
"Sending device list update notif for %r to: %r",
user_id,
hosts,
)
for host in hosts:
self.federation_sender.send_device_messages(
host, immediate=False
)
log_kv(
{"message": "sent device update to host", "host": host}
)
if current_stream_id != stream_id:
# Clear the set of hosts we've already sent to as we're
# processing a new update.
hosts_already_sent_to.clear()
hosts_already_sent_to.update(hosts)
current_stream_id = stream_id
finally:
self._handle_new_device_update_is_processing = False
def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]

View file

@ -44,6 +44,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
],
)
device_list_max = self._device_list_id_gen.get_current_token()

View file

@ -146,6 +146,7 @@ class DataStore(
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
],
)

View file

@ -810,6 +810,7 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
@ -1528,7 +1529,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def add_device_change_to_streams(
self, user_id: str, device_ids: Collection[str], hosts: Collection[str]
self,
user_id: str,
device_ids: Collection[str],
hosts: Optional[Collection[str]],
room_ids: Collection[str],
) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
@ -1537,7 +1542,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: The ID of the user whose device changed.
device_ids: The IDs of any changed devices. If empty, this function will
return None.
hosts: The remote destinations that should be notified of the change.
hosts: The remote destinations that should be notified of the change. If
None then the set of hosts have *not* been calculated, and will be
calculated later by a background task.
room_ids: The rooms that the user is in
Returns:
The maximum stream ID of device list updates that were added to the database, or
@ -1546,34 +1554,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return None
async with self._device_list_id_gen.get_next_mult(
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
"add_device_change_to_stream",
self._add_device_change_to_stream_txn,
context = get_active_span_text_map()
def add_device_changes_txn(
txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes
):
self._add_device_change_to_stream_txn(
txn,
user_id,
device_ids,
stream_ids,
stream_ids_for_device_change,
)
if not hosts:
return stream_ids[-1]
self._add_device_outbound_room_poke_txn(
txn,
user_id,
device_ids,
room_ids,
stream_ids_for_device_change,
context,
hosts_have_been_calculated=hosts is not None,
)
context = get_active_span_text_map()
async with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
"add_device_outbound_poke_to_stream",
self._add_device_outbound_poke_to_stream_txn,
# If the set of hosts to send to has not been calculated yet (and so
# `hosts` is None) or there are no `hosts` to send to, then skip
# trying to persist them to the DB.
if not hosts:
return
self._add_device_outbound_poke_to_stream_txn(
txn,
user_id,
device_ids,
hosts,
stream_ids,
stream_ids_for_outbound_pokes,
context,
)
# `device_lists_stream` wants a stream ID per device update.
num_stream_ids = len(device_ids)
if hosts:
# `device_lists_outbound_pokes` wants a different stream ID for
# each row, which is a row per host per device update.
num_stream_ids += len(hosts) * len(device_ids)
async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids:
stream_ids_for_device_change = stream_ids[: len(device_ids)]
stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]
await self.db_pool.runInteraction(
"add_device_change_to_stream",
add_device_changes_txn,
stream_ids_for_device_change,
stream_ids_for_outbound_pokes,
)
return stream_ids[-1]
def _add_device_change_to_stream_txn(
@ -1617,7 +1653,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_ids: Iterable[str],
hosts: Collection[str],
stream_ids: List[str],
stream_ids: List[int],
context: Dict[str, str],
) -> None:
for host in hosts:
@ -1628,8 +1664,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
now = self._clock.time_msec()
next_stream_id = iter(stream_ids)
stream_id_iterator = iter(stream_ids)
encoded_context = json_encoder.encode(context)
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
@ -1645,16 +1682,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values=[
(
destination,
next(next_stream_id),
next(stream_id_iterator),
user_id,
device_id,
False,
now,
json_encoder.encode(context)
if whitelisted_homeserver(destination)
else "{}",
encoded_context if whitelisted_homeserver(destination) else "{}",
)
for destination in hosts
for device_id in device_ids
],
)
def _add_device_outbound_room_poke_txn(
self,
txn: LoggingTransaction,
user_id: str,
device_ids: Iterable[str],
room_ids: Collection[str],
stream_ids: List[str],
context: Dict[str, str],
hosts_have_been_calculated: bool,
) -> None:
"""Record the user in the room has updated their device.
Args:
hosts_have_been_calculated: True if `device_lists_outbound_pokes`
has been updated already with the updates.
"""
# We only need to convert to outbound pokes if they are our user.
converted_to_destinations = (
hosts_have_been_calculated or not self.hs.is_mine_id(user_id)
)
encoded_context = json_encoder.encode(context)
# The `device_lists_changes_in_room.stream_id` column matches the
# corresponding `stream_id` of the update in the `device_lists_stream`
# table, i.e. all rows persisted for the same device update will have
# the same `stream_id` (but different room IDs).
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_changes_in_room",
keys=(
"user_id",
"device_id",
"room_id",
"stream_id",
"converted_to_destinations",
"opentracing_context",
),
values=[
(
user_id,
device_id,
room_id,
stream_id,
converted_to_destinations,
encoded_context,
)
for room_id in room_ids
for device_id, stream_id in zip(device_ids, stream_ids)
],
)
async def get_uncoverted_outbound_room_pokes(
self, limit: int = 10
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
"""Get device list changes by room that have not yet been handled and
written to `device_lists_outbound_pokes`.
Returns:
A list of user ID, device ID, room ID, stream ID and optional opentracing context.
"""
sql = """
SELECT user_id, device_id, room_id, stream_id, opentracing_context
FROM device_lists_changes_in_room
WHERE NOT converted_to_destinations
ORDER BY stream_id
LIMIT ?
"""
def get_uncoverted_outbound_room_pokes_txn(txn):
txn.execute(sql, (limit,))
return txn.fetchall()
return await self.db_pool.runInteraction(
"get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn
)
async def add_device_list_outbound_pokes(
self,
user_id: str,
device_id: str,
room_id: str,
stream_id: int,
hosts: Collection[str],
context: Optional[Dict[str, str]],
) -> None:
"""Queue the device update to be sent to the given set of hosts,
calculated from the room ID.
Marks the associated row in `device_lists_changes_in_room` as handled.
"""
def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
if hosts:
self._add_device_outbound_poke_to_stream_txn(
txn,
user_id=user_id,
device_ids=[device_id],
hosts=hosts,
stream_ids=stream_ids,
context=context,
)
self.db_pool.simple_update_txn(
txn,
table="device_lists_changes_in_room",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"stream_id": stream_id,
"room_id": room_id,
},
updatevalues={"converted_to_destinations": True},
)
if not hosts:
# If there are no hosts then we don't try and generate stream IDs.
return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn,
[],
)
async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn,
stream_ids,
)

View file

@ -60,6 +60,7 @@ Changes in SCHEMA_VERSION = 68:
new events.
Changes in SCHEMA_VERSION = 69:
- We now write to `device_lists_changes_in_room` table.
- Use sequence to generate future `application_services_txns.txn_id`s
"""

View file

@ -0,0 +1,38 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE device_lists_changes_in_room (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
room_id TEXT NOT NULL,
-- This initially matches `device_lists_stream.stream_id`. Note that we
-- delete older values from `device_lists_stream`, so we can't use a foreign
-- constraint here.
--
-- The table will contain rows with the same `stream_id` but different
-- `room_id`, as for each device update we store a row per room the user is
-- joined to. Therefore `(stream_id, room_id)` gives a unique index.
stream_id BIGINT NOT NULL,
-- We have a background process which goes through this table and converts
-- entries into rows in `device_lists_outbound_pokes`. Once we have processed
-- a row, we mark it as such by setting `converted_to_destinations=TRUE`.
converted_to_destinations BOOLEAN NOT NULL,
opentracing_context TEXT
);
CREATE UNIQUE INDEX device_lists_changes_in_stream_id ON device_lists_changes_in_room(stream_id, room_id);
CREATE INDEX device_lists_changes_in_stream_id_unconverted ON device_lists_changes_in_room(stream_id) WHERE NOT converted_to_destinations;

View file

@ -14,6 +14,7 @@
from typing import Optional
from unittest.mock import Mock
from parameterized import parameterized_class
from signedjson import key, sign
from signedjson.types import BaseKey, SigningKey
@ -154,6 +155,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
)
@parameterized_class(
[
{"enable_room_poke_code_path": False},
{"enable_room_poke_code_path": True},
]
)
class FederationSenderDevicesTestCases(HomeserverTestCase):
servlets = [
admin.register_servlets,
@ -168,17 +175,21 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
def default_config(self):
c = super().default_config()
c["send_federation"] = True
c["use_new_device_lists_changes_in_room"] = self.enable_room_poke_code_path
return c
def prepare(self, reactor, clock, hs):
# stub out get_users_who_share_room_with_user so that it claims that
# `@user2:host2` is in the room
def get_users_who_share_room_with_user(user_id):
# stub out `get_rooms_for_user` and `get_users_in_room` so that the
# server thinks the user shares a room with `@user2:host2`
def get_rooms_for_user(user_id):
return defer.succeed({"!room:host1"})
hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
def get_users_in_room(room_id):
return defer.succeed({"@user2:host2"})
hs.get_datastores().main.get_users_who_share_room_with_user = (
get_users_who_share_room_with_user
)
hs.get_datastores().main.get_users_in_room = get_users_in_room
# whenever send_transaction is called, record the edu data
self.edus = []

View file

@ -96,7 +96,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Add two device updates with sequential `stream_id`s
self.get_success(
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
self.store.add_device_change_to_streams(
"user_id", device_ids, ["somehost"], ["!some:room"]
)
)
# Get all device updates ever meant for this remote
@ -122,7 +124,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
"device_id5",
]
self.get_success(
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
self.store.add_device_change_to_streams(
"user_id", device_ids, ["somehost"], ["!some:room"]
)
)
# Get device updates meant for this remote
@ -144,7 +148,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Add some more device updates to ensure it still resumes properly
device_ids = ["device_id6", "device_id7"]
self.get_success(
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
self.store.add_device_change_to_streams(
"user_id", device_ids, ["somehost"], ["!some:room"]
)
)
# Get the next batch of device updates
@ -220,7 +226,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
self.get_success(
self.store.add_device_change_to_streams(
"@user_id:test", device_ids, ["somehost"]
"@user_id:test", device_ids, ["somehost"], ["!some:room"]
)
)