mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-21 17:15:38 +03:00
Notify when one-time keys are claimed
Fix https://github.com/element-hq/synapse/issues/17474
This commit is contained in:
parent
f6a3e5e1c2
commit
878d427d20
3 changed files with 224 additions and 30 deletions
|
@ -70,6 +70,7 @@ class E2eKeysHandler:
|
|||
self.is_mine = hs.is_mine
|
||||
self.clock = hs.get_clock()
|
||||
self._worker_lock_handler = hs.get_worker_locks_handler()
|
||||
self._notifier = hs.get_notifier()
|
||||
|
||||
federation_registry = hs.get_federation_registry()
|
||||
|
||||
|
@ -615,7 +616,7 @@ class E2eKeysHandler:
|
|||
3. Attempt to fetch fallback keys from the database.
|
||||
|
||||
Args:
|
||||
local_query: An iterable of tuples of (user ID, device ID, algorithm).
|
||||
local_query: An iterable of tuples of (user ID, device ID, algorithm, number of keys to claim).
|
||||
always_include_fallback_keys: True to always include fallback keys.
|
||||
|
||||
Returns:
|
||||
|
@ -629,6 +630,7 @@ class E2eKeysHandler:
|
|||
]
|
||||
|
||||
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
|
||||
self._notifier.notify_one_time_keys_claimed(otk_results.keys())
|
||||
|
||||
# If the application services have not provided any keys via the C-S
|
||||
# API, query it directly for one-time keys.
|
||||
|
@ -639,6 +641,7 @@ class E2eKeysHandler:
|
|||
appservice_results,
|
||||
not_found,
|
||||
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
|
||||
self._notifier.notify_one_time_keys_claimed(appservice_results.keys())
|
||||
else:
|
||||
appservice_results = {}
|
||||
|
||||
|
@ -693,6 +696,7 @@ class E2eKeysHandler:
|
|||
# For each user that does not have a one-time keys available, see if
|
||||
# there is a fallback key.
|
||||
fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)
|
||||
self._notifier.notify_one_time_keys_claimed(fallback_results.keys())
|
||||
|
||||
# Return the results in order, each item from the input query should
|
||||
# only appear once in the combined list.
|
||||
|
|
|
@ -590,6 +590,8 @@ class Notifier:
|
|||
without waking up any of the normal user event streams"""
|
||||
self.notify_replication()
|
||||
|
||||
# FIXME: Should this be renamed to `wait_for_activity`? This listens for new events
|
||||
# and when one-time keys are claimed which doesn't correspond to an event.
|
||||
async def wait_for_events(
|
||||
self,
|
||||
user_id: str,
|
||||
|
@ -900,6 +902,40 @@ class Notifier:
|
|||
for cb in self._lock_released_callback:
|
||||
cb(instance_name, lock_name, lock_key)
|
||||
|
||||
def notify_one_time_keys_claimed(
|
||||
self,
|
||||
users: Union[StrCollection, Collection[UserID]],
|
||||
) -> None:
|
||||
"""
|
||||
Used by handlers to inform the notifier that a one-time key has been
|
||||
claimed
|
||||
"""
|
||||
# Bail early if there is nothing to do
|
||||
if not users:
|
||||
return
|
||||
|
||||
time_now_ms = self.clock.time_msec()
|
||||
current_token = self.event_sources.get_current_token()
|
||||
listeners: List["Deferred[StreamToken]"] = []
|
||||
for user in users:
|
||||
user_stream = self.user_to_user_stream.get(str(user))
|
||||
if user_stream is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
listeners.extend(
|
||||
user_stream.update_and_fetch_deferreds(current_token, time_now_ms)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to notify listener")
|
||||
|
||||
# We resolve all these deferreds in one go so that we only need to
|
||||
# call `PreserveLoggingContext` once, as it has a bunch of overhead
|
||||
# (to calculate performance stats)
|
||||
with PreserveLoggingContext():
|
||||
for listener in listeners:
|
||||
listener.callback(current_token)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class ReplicationNotifier:
|
||||
|
|
|
@ -11,9 +11,11 @@
|
|||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
import enum
|
||||
import logging
|
||||
|
||||
from parameterized import parameterized_class
|
||||
from parameterized import parameterized, parameterized_class
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
|
@ -29,6 +31,12 @@ from tests.server import TimedOutException
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class E2eeBumpAction(enum.Enum):
|
||||
device_lists = enum.auto()
|
||||
one_time_keys = enum.auto()
|
||||
fallback_one_time_keys = enum.auto()
|
||||
|
||||
|
||||
# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
|
||||
# foreground update for
|
||||
# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
|
||||
|
@ -147,19 +155,38 @@ class SlidingSyncE2eeExtensionTestCase(SlidingSyncBase):
|
|||
[],
|
||||
)
|
||||
|
||||
def test_wait_for_new_data(self) -> None:
|
||||
@parameterized.expand(
|
||||
[
|
||||
(
|
||||
"bump_device_lists",
|
||||
E2eeBumpAction.device_lists,
|
||||
),
|
||||
(
|
||||
"bump_one_time_keys",
|
||||
E2eeBumpAction.one_time_keys,
|
||||
),
|
||||
(
|
||||
"bump_fallback_one_time_keys",
|
||||
E2eeBumpAction.fallback_one_time_keys,
|
||||
),
|
||||
]
|
||||
)
|
||||
def test_wait_for_new_data(
|
||||
self, test_description: str, bump_action: E2eeBumpAction
|
||||
) -> None:
|
||||
"""
|
||||
Test to make sure that the Sliding Sync request waits for new data to arrive.
|
||||
|
||||
(Only applies to incremental syncs with a `timeout` specified)
|
||||
"""
|
||||
test_device_id = "TESTDEVICE"
|
||||
user1_id = self.register_user("user1", "pass")
|
||||
user1_tok = self.login(user1_id, "pass")
|
||||
user1_tok = self.login(user1_id, "pass", device_id=test_device_id)
|
||||
user2_id = self.register_user("user2", "pass")
|
||||
user2_tok = self.login(user2_id, "pass")
|
||||
test_device_id = "TESTDEVICE"
|
||||
other_user_test_device_id = "OTHERUSERTESTDEVICE"
|
||||
user3_id = self.register_user("user3", "pass")
|
||||
user3_tok = self.login(user3_id, "pass", device_id=test_device_id)
|
||||
user3_tok = self.login(user3_id, "pass", device_id=other_user_test_device_id)
|
||||
|
||||
room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
|
||||
self.helper.join(room_id, user1_id, tok=user1_tok)
|
||||
|
@ -186,34 +213,161 @@ class SlidingSyncE2eeExtensionTestCase(SlidingSyncBase):
|
|||
# Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
|
||||
with self.assertRaises(TimedOutException):
|
||||
channel.await_result(timeout_ms=5000)
|
||||
# Bump the device lists to trigger new results
|
||||
# Have user3 update their device list
|
||||
device_update_channel = self.make_request(
|
||||
"PUT",
|
||||
f"/devices/{test_device_id}",
|
||||
{
|
||||
"display_name": "New Device Name",
|
||||
},
|
||||
access_token=user3_tok,
|
||||
)
|
||||
self.assertEqual(
|
||||
device_update_channel.code, 200, device_update_channel.json_body
|
||||
)
|
||||
|
||||
if bump_action == E2eeBumpAction.device_lists:
|
||||
# Bump the device lists to trigger new results
|
||||
# Have user3 update their device list
|
||||
device_update_channel = self.make_request(
|
||||
"PUT",
|
||||
f"/devices/{other_user_test_device_id}",
|
||||
{
|
||||
"display_name": "New Device Name",
|
||||
},
|
||||
access_token=user3_tok,
|
||||
)
|
||||
self.assertEqual(
|
||||
device_update_channel.code, 200, device_update_channel.json_body
|
||||
)
|
||||
elif bump_action == E2eeBumpAction.one_time_keys:
|
||||
# Upload one time keys for the user/device
|
||||
keys: JsonDict = {
|
||||
"alg1:k1": "key1",
|
||||
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
|
||||
"alg2:k3": {"key": "key3"},
|
||||
}
|
||||
upload_keys_response = self.get_success(
|
||||
self.e2e_keys_handler.upload_keys_for_user(
|
||||
user1_id, test_device_id, {"one_time_keys": keys}
|
||||
)
|
||||
)
|
||||
self.assertDictEqual(
|
||||
upload_keys_response,
|
||||
{
|
||||
"one_time_key_counts": {
|
||||
"alg1": 1,
|
||||
"alg2": 2,
|
||||
# Note that "signed_curve25519" is always returned in key count responses
|
||||
# regardless of whether we uploaded any keys for it. This is necessary until
|
||||
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
|
||||
#
|
||||
# Also related:
|
||||
# https://github.com/element-hq/element-android/issues/3725 and
|
||||
# https://github.com/matrix-org/synapse/issues/10456
|
||||
"signed_curve25519": 0,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Claim one of those new keys
|
||||
self.get_success(
|
||||
self.e2e_keys_handler.claim_local_one_time_keys(
|
||||
local_query=[(user1_id, test_device_id, "alg2", 1)],
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: We should have a way to let clients differentiate between the states of:
|
||||
# * no change in OTK count since the provided since token
|
||||
# * the server has zero OTKs left for this device
|
||||
# Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
|
||||
elif bump_action == E2eeBumpAction.fallback_one_time_keys:
|
||||
# Upload a fallback key for the user/device
|
||||
self.get_success(
|
||||
self.e2e_keys_handler.upload_keys_for_user(
|
||||
user1_id,
|
||||
test_device_id,
|
||||
{
|
||||
"fallback_keys": {
|
||||
"alg1:k1": "fallback_key1",
|
||||
"alg2:k2": "fallback_key2",
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
# We should now have an unused alg1 and alg2 key
|
||||
fallback_res = self.get_success(
|
||||
self.store.get_e2e_unused_fallback_key_types(user1_id, test_device_id)
|
||||
)
|
||||
self.assertIncludes(
|
||||
set(fallback_res),
|
||||
{"alg1", "alg2"},
|
||||
exact=True,
|
||||
message=str(fallback_res),
|
||||
)
|
||||
|
||||
# Claim one of those fallback keys
|
||||
self.get_success(
|
||||
self.e2e_keys_handler.claim_local_one_time_keys(
|
||||
local_query=[(user1_id, test_device_id, "alg1", 1)],
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: We should have a way to let clients differentiate between the states of:
|
||||
# * no change in OTK count since the provided since token
|
||||
# * the server has zero OTKs left for this device
|
||||
# Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
|
||||
else:
|
||||
assert_never(bump_action)
|
||||
|
||||
# Should respond before the 10 second timeout
|
||||
channel.await_result(timeout_ms=3000)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
# We should see the device list update
|
||||
self.assertEqual(
|
||||
channel.json_body["extensions"]["e2ee"]
|
||||
.get("device_lists", {})
|
||||
.get("changed"),
|
||||
[user3_id],
|
||||
)
|
||||
self.assertEqual(
|
||||
channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
|
||||
[],
|
||||
)
|
||||
# Check for the new data
|
||||
#
|
||||
if bump_action == E2eeBumpAction.device_lists:
|
||||
# We should see the device list update
|
||||
self.assertEqual(
|
||||
channel.json_body["extensions"]["e2ee"]
|
||||
.get("device_lists", {})
|
||||
.get("changed"),
|
||||
[user3_id],
|
||||
)
|
||||
self.assertEqual(
|
||||
channel.json_body["extensions"]["e2ee"]
|
||||
.get("device_lists", {})
|
||||
.get("left"),
|
||||
[],
|
||||
)
|
||||
elif bump_action == E2eeBumpAction.one_time_keys:
|
||||
# We should see the one-time key count change
|
||||
self.assertEqual(
|
||||
channel.json_body["extensions"]["e2ee"].get(
|
||||
"device_one_time_keys_count"
|
||||
),
|
||||
{
|
||||
"alg1": 1,
|
||||
# Note: This changed from 2 -> 1 since we claimed one of them
|
||||
"alg2": 1,
|
||||
# Note that "signed_curve25519" is always returned in key count responses
|
||||
# regardless of whether we uploaded any keys for it. This is necessary until
|
||||
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
|
||||
#
|
||||
# Also related:
|
||||
# https://github.com/element-hq/element-android/issues/3725 and
|
||||
# https://github.com/matrix-org/synapse/issues/10456
|
||||
"signed_curve25519": 0,
|
||||
},
|
||||
)
|
||||
elif bump_action == E2eeBumpAction.fallback_one_time_keys:
|
||||
# Check for the unused fallback key types
|
||||
self.assertIncludes(
|
||||
set(
|
||||
channel.json_body["extensions"]["e2ee"].get(
|
||||
"device_unused_fallback_key_types", []
|
||||
)
|
||||
),
|
||||
{"alg2"},
|
||||
exact=True,
|
||||
message=str(
|
||||
channel.json_body["extensions"]["e2ee"].get(
|
||||
"device_unused_fallback_key_types",
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
assert_never(bump_action)
|
||||
|
||||
def test_wait_for_new_data_timeout(self) -> None:
|
||||
"""
|
||||
|
|
Loading…
Reference in a new issue