Notify when one-time keys are claimed

Fix https://github.com/element-hq/synapse/issues/17474
This commit is contained in:
Eric Eastwood 2024-10-10 15:06:15 -05:00
parent f6a3e5e1c2
commit 878d427d20
3 changed files with 224 additions and 30 deletions

View file

@ -70,6 +70,7 @@ class E2eKeysHandler:
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
self._worker_lock_handler = hs.get_worker_locks_handler()
self._notifier = hs.get_notifier()
federation_registry = hs.get_federation_registry()
@ -615,7 +616,7 @@ class E2eKeysHandler:
3. Attempt to fetch fallback keys from the database.
Args:
local_query: An iterable of tuples of (user ID, device ID, algorithm).
local_query: An iterable of tuples of (user ID, device ID, algorithm, number of keys to claim).
always_include_fallback_keys: True to always include fallback keys.
Returns:
@ -629,6 +630,7 @@ class E2eKeysHandler:
]
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
self._notifier.notify_one_time_keys_claimed(otk_results.keys())
# If the application services have not provided any keys via the C-S
# API, query it directly for one-time keys.
@ -639,6 +641,7 @@ class E2eKeysHandler:
appservice_results,
not_found,
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
self._notifier.notify_one_time_keys_claimed(appservice_results.keys())
else:
appservice_results = {}
@ -693,6 +696,7 @@ class E2eKeysHandler:
# For each user that does not have a one-time keys available, see if
# there is a fallback key.
fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)
self._notifier.notify_one_time_keys_claimed(fallback_results.keys())
# Return the results in order, each item from the input query should
# only appear once in the combined list.

View file

@ -590,6 +590,8 @@ class Notifier:
without waking up any of the normal user event streams"""
self.notify_replication()
# FIXME: Should this be renamed to `wait_for_activity`? This listens for new events
# and when one-time keys are claimed which doesn't correspond to an event.
async def wait_for_events(
self,
user_id: str,
@ -900,6 +902,40 @@ class Notifier:
for cb in self._lock_released_callback:
cb(instance_name, lock_name, lock_key)
def notify_one_time_keys_claimed(
self,
users: Union[StrCollection, Collection[UserID]],
) -> None:
"""
Used by handlers to inform the notifier that a one-time key has been
claimed
"""
# Bail early if there is nothing to do
if not users:
return
time_now_ms = self.clock.time_msec()
current_token = self.event_sources.get_current_token()
listeners: List["Deferred[StreamToken]"] = []
for user in users:
user_stream = self.user_to_user_stream.get(str(user))
if user_stream is None:
continue
try:
listeners.extend(
user_stream.update_and_fetch_deferreds(current_token, time_now_ms)
)
except Exception:
logger.exception("Failed to notify listener")
# We resolve all these deferreds in one go so that we only need to
# call `PreserveLoggingContext` once, as it has a bunch of overhead
# (to calculate performance stats)
with PreserveLoggingContext():
for listener in listeners:
listener.callback(current_token)
@attr.s(auto_attribs=True)
class ReplicationNotifier:

View file

@ -11,9 +11,11 @@
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
import enum
import logging
from parameterized import parameterized_class
from parameterized import parameterized, parameterized_class
from typing_extensions import assert_never
from twisted.test.proto_helpers import MemoryReactor
@ -29,6 +31,12 @@ from tests.server import TimedOutException
logger = logging.getLogger(__name__)
class E2eeBumpAction(enum.Enum):
device_lists = enum.auto()
one_time_keys = enum.auto()
fallback_one_time_keys = enum.auto()
# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
# foreground update for
# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
@ -147,19 +155,38 @@ class SlidingSyncE2eeExtensionTestCase(SlidingSyncBase):
[],
)
def test_wait_for_new_data(self) -> None:
@parameterized.expand(
[
(
"bump_device_lists",
E2eeBumpAction.device_lists,
),
(
"bump_one_time_keys",
E2eeBumpAction.one_time_keys,
),
(
"bump_fallback_one_time_keys",
E2eeBumpAction.fallback_one_time_keys,
),
]
)
def test_wait_for_new_data(
self, test_description: str, bump_action: E2eeBumpAction
) -> None:
"""
Test to make sure that the Sliding Sync request waits for new data to arrive.
(Only applies to incremental syncs with a `timeout` specified)
"""
test_device_id = "TESTDEVICE"
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user1_tok = self.login(user1_id, "pass", device_id=test_device_id)
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
test_device_id = "TESTDEVICE"
other_user_test_device_id = "OTHERUSERTESTDEVICE"
user3_id = self.register_user("user3", "pass")
user3_tok = self.login(user3_id, "pass", device_id=test_device_id)
user3_tok = self.login(user3_id, "pass", device_id=other_user_test_device_id)
room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.join(room_id, user1_id, tok=user1_tok)
@ -186,11 +213,13 @@ class SlidingSyncE2eeExtensionTestCase(SlidingSyncBase):
# Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
with self.assertRaises(TimedOutException):
channel.await_result(timeout_ms=5000)
if bump_action == E2eeBumpAction.device_lists:
# Bump the device lists to trigger new results
# Have user3 update their device list
device_update_channel = self.make_request(
"PUT",
f"/devices/{test_device_id}",
f"/devices/{other_user_test_device_id}",
{
"display_name": "New Device Name",
},
@ -199,10 +228,95 @@ class SlidingSyncE2eeExtensionTestCase(SlidingSyncBase):
self.assertEqual(
device_update_channel.code, 200, device_update_channel.json_body
)
elif bump_action == E2eeBumpAction.one_time_keys:
# Upload one time keys for the user/device
keys: JsonDict = {
"alg1:k1": "key1",
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"alg2:k3": {"key": "key3"},
}
upload_keys_response = self.get_success(
self.e2e_keys_handler.upload_keys_for_user(
user1_id, test_device_id, {"one_time_keys": keys}
)
)
self.assertDictEqual(
upload_keys_response,
{
"one_time_key_counts": {
"alg1": 1,
"alg2": 2,
# Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
#
# Also related:
# https://github.com/element-hq/element-android/issues/3725 and
# https://github.com/matrix-org/synapse/issues/10456
"signed_curve25519": 0,
}
},
)
# Claim one of those new keys
self.get_success(
self.e2e_keys_handler.claim_local_one_time_keys(
local_query=[(user1_id, test_device_id, "alg2", 1)],
always_include_fallback_keys=False,
)
)
# TODO: We should have a way to let clients differentiate between the states of:
# * no change in OTK count since the provided since token
# * the server has zero OTKs left for this device
# Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
elif bump_action == E2eeBumpAction.fallback_one_time_keys:
# Upload a fallback key for the user/device
self.get_success(
self.e2e_keys_handler.upload_keys_for_user(
user1_id,
test_device_id,
{
"fallback_keys": {
"alg1:k1": "fallback_key1",
"alg2:k2": "fallback_key2",
}
},
)
)
# We should now have an unused alg1 and alg2 key
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(user1_id, test_device_id)
)
self.assertIncludes(
set(fallback_res),
{"alg1", "alg2"},
exact=True,
message=str(fallback_res),
)
# Claim one of those fallback keys
self.get_success(
self.e2e_keys_handler.claim_local_one_time_keys(
local_query=[(user1_id, test_device_id, "alg1", 1)],
always_include_fallback_keys=False,
)
)
# TODO: We should have a way to let clients differentiate between the states of:
# * no change in OTK count since the provided since token
# * the server has zero OTKs left for this device
# Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
else:
assert_never(bump_action)
# Should respond before the 10 second timeout
channel.await_result(timeout_ms=3000)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for the new data
#
if bump_action == E2eeBumpAction.device_lists:
# We should see the device list update
self.assertEqual(
channel.json_body["extensions"]["e2ee"]
@ -211,9 +325,49 @@ class SlidingSyncE2eeExtensionTestCase(SlidingSyncBase):
[user3_id],
)
self.assertEqual(
channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
channel.json_body["extensions"]["e2ee"]
.get("device_lists", {})
.get("left"),
[],
)
elif bump_action == E2eeBumpAction.one_time_keys:
# We should see the one-time key count change
self.assertEqual(
channel.json_body["extensions"]["e2ee"].get(
"device_one_time_keys_count"
),
{
"alg1": 1,
# Note: This changed from 2 -> 1 since we claimed one of them
"alg2": 1,
# Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
#
# Also related:
# https://github.com/element-hq/element-android/issues/3725 and
# https://github.com/matrix-org/synapse/issues/10456
"signed_curve25519": 0,
},
)
elif bump_action == E2eeBumpAction.fallback_one_time_keys:
# Check for the unused fallback key types
self.assertIncludes(
set(
channel.json_body["extensions"]["e2ee"].get(
"device_unused_fallback_key_types", []
)
),
{"alg2"},
exact=True,
message=str(
channel.json_body["extensions"]["e2ee"].get(
"device_unused_fallback_key_types",
)
),
)
else:
assert_never(bump_action)
def test_wait_for_new_data_timeout(self) -> None:
"""