Merge branch 'madlittlemods/msc3575-sliding-sync-0.0.1' into madlittlemods/msc3575-sliding-sync-filtering

This commit is contained in:
Eric Eastwood 2024-06-06 14:40:59 -05:00
commit b457c0b2e2
8 changed files with 218 additions and 56 deletions

1
changelog.d/17271.misc Normal file
View file

@ -0,0 +1 @@
Handle OTK uploads off master.

1
changelog.d/17273.misc Normal file
View file

@ -0,0 +1 @@
Don't try and resync devices for remote users whose servers are marked as down.

1
changelog.d/17275.bugfix Normal file
View file

@ -0,0 +1 @@
Fix bug where OTKs were not always included in `/sync` response when using workers.

View file

@ -35,6 +35,7 @@ from synapse.api.errors import CodeMessageException, Codes, NotFoundError, Synap
from synapse.handlers.device import DeviceHandler from synapse.handlers.device import DeviceHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.types import ( from synapse.types import (
JsonDict, JsonDict,
JsonMapping, JsonMapping,
@ -45,7 +46,10 @@ from synapse.types import (
from synapse.util import json_decoder from synapse.util import json_decoder
from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.cancellation import cancellable from synapse.util.cancellation import cancellable
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import (
NotRetryingDestination,
filter_destinations_by_retry_limiter,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -86,6 +90,12 @@ class E2eKeysHandler:
edu_updater.incoming_signing_key_update, edu_updater.incoming_signing_key_update,
) )
self.device_key_uploader = self.upload_device_keys_for_user
else:
self.device_key_uploader = (
ReplicationUploadKeysForUserRestServlet.make_client(hs)
)
# doesn't really work as part of the generic query API, because the # doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the # query request requires an object POST, but we abuse the
# "query handler" interface. # "query handler" interface.
@ -268,10 +278,8 @@ class E2eKeysHandler:
"%d destinations to query devices for", len(remote_queries_not_in_cache) "%d destinations to query devices for", len(remote_queries_not_in_cache)
) )
async def _query( async def _query(destination: str) -> None:
destination_queries: Tuple[str, Dict[str, Iterable[str]]] queries = remote_queries_not_in_cache[destination]
) -> None:
destination, queries = destination_queries
return await self._query_devices_for_destination( return await self._query_devices_for_destination(
results, results,
cross_signing_keys, cross_signing_keys,
@ -281,9 +289,20 @@ class E2eKeysHandler:
timeout, timeout,
) )
# Only try and fetch keys for destinations that are not marked as
# down.
filtered_destinations = await filter_destinations_by_retry_limiter(
remote_queries_not_in_cache.keys(),
self.clock,
self.store,
# Let's give an arbitrary grace period for those hosts that are
# only recently down
retry_due_within_ms=60 * 1000,
)
await concurrently_execute( await concurrently_execute(
_query, _query,
remote_queries_not_in_cache.items(), filtered_destinations,
10, 10,
delay_cancellation=True, delay_cancellation=True,
) )
@ -784,36 +803,17 @@ class E2eKeysHandler:
"one_time_keys": A mapping from algorithm to number of keys for that "one_time_keys": A mapping from algorithm to number of keys for that
algorithm, including those previously persisted. algorithm, including those previously persisted.
""" """
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys. # TODO: Validate the JSON to make sure it has the right keys.
device_keys = keys.get("device_keys", None) device_keys = keys.get("device_keys", None)
if device_keys: if device_keys:
logger.info( await self.device_key_uploader(
"Updating device_keys for device %r for user %s at %d", user_id=user_id,
device_id, device_id=device_id,
user_id, keys={"device_keys": device_keys},
time_now,
) )
log_kv(
{
"message": "Updating device_keys for user.",
"user_id": user_id,
"device_id": device_id,
}
)
# TODO: Sign the JSON with the server key
changed = await self.store.set_e2e_device_keys(
user_id, device_id, time_now, device_keys
)
if changed:
# Only notify about device updates *if* the keys actually changed
await self.device_handler.notify_device_update(user_id, [device_id])
else:
log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
one_time_keys = keys.get("one_time_keys", None) one_time_keys = keys.get("one_time_keys", None)
if one_time_keys: if one_time_keys:
log_kv( log_kv(
@ -849,6 +849,49 @@ class E2eKeysHandler:
{"message": "Did not update fallback_keys", "reason": "no keys given"} {"message": "Did not update fallback_keys", "reason": "no keys given"}
) )
result = await self.store.count_e2e_one_time_keys(user_id, device_id)
set_tag("one_time_key_counts", str(result))
return {"one_time_key_counts": result}
@tag_args
async def upload_device_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> None:
"""
Args:
user_id: user whose keys are being uploaded.
device_id: device whose keys are being uploaded.
device_keys: the `device_keys` of an /keys/upload request.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)
time_now = self.clock.time_msec()
device_keys = keys["device_keys"]
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id,
user_id,
time_now,
)
log_kv(
{
"message": "Updating device_keys for user.",
"user_id": user_id,
"device_id": device_id,
}
)
# TODO: Sign the JSON with the server key
changed = await self.store.set_e2e_device_keys(
user_id, device_id, time_now, device_keys
)
if changed:
# Only notify about device updates *if* the keys actually changed
await self.device_handler.notify_device_update(user_id, [device_id])
# the device should have been registered already, but it may have been # the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an # deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we # old access_token without an associated device_id. Either way, we
@ -856,11 +899,6 @@ class E2eKeysHandler:
# keys without a corresponding device. # keys without a corresponding device.
await self.device_handler.check_device_registered(user_id, device_id) await self.device_handler.check_device_registered(user_id, device_id)
result = await self.store.count_e2e_one_time_keys(user_id, device_id)
set_tag("one_time_key_counts", str(result))
return {"one_time_key_counts": result}
async def _upload_one_time_keys_for_user( async def _upload_one_time_keys_for_user(
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
) -> None: ) -> None:

View file

@ -53,14 +53,13 @@ def filter_membership_for_sync(*, membership: str, user_id: str, sender: str) ->
sender: The person who sent the membership event sender: The person who sent the membership event
""" """
return ( # Everything except `Membership.LEAVE` because we want everything that's *still*
# Everything except `Membership.LEAVE` because we want everything that's *still* # relevant to the user. There are few more things to include in the sync response
# relevant to the user. There are few more things to include in the sync response # (newly_left) but those are handled separately.
# (newly_left) but those are handled separately. #
membership in (Membership.LIST - {Membership.LEAVE}) # This logic includes kicks (leave events where the sender is not the same user) and
# Include kicks # can be read as "anything that isn't a leave or a leave with a different sender".
or (membership == Membership.LEAVE and sender != user_id) return membership != Membership.LEAVE or sender != user_id
)
class SlidingSyncConfig(SlidingSyncBody): class SlidingSyncConfig(SlidingSyncBody):

View file

@ -285,7 +285,11 @@ class SyncResult:
) )
@staticmethod @staticmethod
def empty(next_batch: StreamToken) -> "SyncResult": def empty(
next_batch: StreamToken,
device_one_time_keys_count: JsonMapping,
device_unused_fallback_key_types: List[str],
) -> "SyncResult":
"Return a new empty result" "Return a new empty result"
return SyncResult( return SyncResult(
next_batch=next_batch, next_batch=next_batch,
@ -297,8 +301,8 @@ class SyncResult:
archived=[], archived=[],
to_device=[], to_device=[],
device_lists=DeviceListUpdates(), device_lists=DeviceListUpdates(),
device_one_time_keys_count={}, device_one_time_keys_count=device_one_time_keys_count,
device_unused_fallback_key_types=[], device_unused_fallback_key_types=device_unused_fallback_key_types,
) )
@ -523,7 +527,28 @@ class SyncHandler:
logger.warning( logger.warning(
"Timed out waiting for worker to catch up. Returning empty response" "Timed out waiting for worker to catch up. Returning empty response"
) )
return SyncResult.empty(since_token) device_id = sync_config.device_id
one_time_keys_count: JsonMapping = {}
unused_fallback_key_types: List[str] = []
if device_id:
user_id = sync_config.user.to_string()
# TODO: We should have a way to let clients differentiate between the states of:
# * no change in OTK count since the provided since token
# * the server has zero OTKs left for this device
# Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
one_time_keys_count = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
unused_fallback_key_types = list(
await self.store.get_e2e_unused_fallback_key_types(
user_id, device_id
)
)
cache_context.should_cache = False # Don't cache empty responses
return SyncResult.empty(
since_token, one_time_keys_count, unused_fallback_key_types
)
# If we've spent significant time waiting to catch up, take it off # If we've spent significant time waiting to catch up, take it off
# the timeout. # the timeout.

View file

@ -36,7 +36,6 @@ from synapse.http.servlet import (
) )
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag from synapse.logging.opentracing import log_kv, set_tag
from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.types import JsonDict, StreamToken from synapse.types import JsonDict, StreamToken
from synapse.util.cancellation import cancellable from synapse.util.cancellation import cancellable
@ -105,13 +104,8 @@ class KeyUploadServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self._clock = hs.get_clock()
if hs.config.worker.worker_app is None: self._store = hs.get_datastores().main
# if main process
self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
else:
# then a worker
self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
async def on_POST( async def on_POST(
self, request: SynapseRequest, device_id: Optional[str] self, request: SynapseRequest, device_id: Optional[str]
@ -151,9 +145,10 @@ class KeyUploadServlet(RestServlet):
400, "To upload keys, you must pass device_id when authenticating" 400, "To upload keys, you must pass device_id when authenticating"
) )
result = await self.key_uploader( result = await self.e2e_keys_handler.upload_keys_for_user(
user_id=user_id, device_id=device_id, keys=body user_id=user_id, device_id=device_id, keys=body
) )
return 200, result return 200, result

View file

@ -547,6 +547,108 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# Room should still show up because it's newly_left during the from/to range # Room should still show up because it's newly_left during the from/to range
self.assertEqual(room_id_results, {room_id1}) self.assertEqual(room_id_results, {room_id1})
def test_no_from_token(self) -> None:
"""
Test that if we don't provide a `from_token`, we get all the rooms that we we're
joined to up to the `to_token`.
Providing `from_token` only really has the effect that it adds `newly_left`
rooms to the response.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
# We create the room with user2 so the room isn't left with no members when we
# leave and can still re-join.
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
# Join room1
self.helper.join(room_id1, user1_id, tok=user1_tok)
# Join and leave the room2 before the `to_token`
self.helper.join(room_id2, user1_id, tok=user1_tok)
self.helper.leave(room_id2, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
# Join the room2 after we already have our tokens
self.helper.join(room_id2, user1_id, tok=user1_tok)
room_id_results = self.get_success(
self.sliding_sync_handler.get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=None,
to_token=after_room1_token,
)
)
# Only rooms we were joined to before the `to_token` should show up
self.assertEqual(room_id_results, {room_id1})
def test_from_token_ahead_of_to_token(self) -> None:
"""
Test when the provided `from_token` comes after the `to_token`. We should
basically expect the same result as having no `from_token`.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
# We create the room with user2 so the room isn't left with no members when we
# leave and can still re-join.
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
room_id3 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
room_id4 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
# Join room1 before `before_room_token`
self.helper.join(room_id1, user1_id, tok=user1_tok)
# Join and leave the room2 before `before_room_token`
self.helper.join(room_id2, user1_id, tok=user1_tok)
self.helper.leave(room_id2, user1_id, tok=user1_tok)
# Note: These are purposely swapped. The `from_token` should come after
# the `to_token` in this test
to_token = self.event_sources.get_current_token()
# Join room2 after `before_room_token`
self.helper.join(room_id2, user1_id, tok=user1_tok)
# --------
# Join room3 after `before_room_token`
self.helper.join(room_id3, user1_id, tok=user1_tok)
# Join and leave the room4 after `before_room_token`
self.helper.join(room_id4, user1_id, tok=user1_tok)
self.helper.leave(room_id4, user1_id, tok=user1_tok)
# Note: These are purposely swapped. The `from_token` should come after the
# `to_token` in this test
from_token = self.event_sources.get_current_token()
# Join the room4 after we already have our tokens
self.helper.join(room_id4, user1_id, tok=user1_tok)
room_id_results = self.get_success(
self.sliding_sync_handler.get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=from_token,
to_token=to_token,
)
)
# Only rooms we were joined to before the `to_token` should show up
#
# There won't be any newly_left rooms because the `from_token` is ahead of the
# `to_token` and that range will give no membership changes to check.
self.assertEqual(room_id_results, {room_id1})
def test_leave_before_range_and_join_leave_after_to_token(self) -> None: def test_leave_before_range_and_join_leave_after_to_token(self) -> None:
""" """
Old left room shouldn't show up. But we're also testing that joining and leaving Old left room shouldn't show up. But we're also testing that joining and leaving