Fix test inheritance

See https://github.com/element-hq/synapse/pull/17167#discussion_r1594517041
This commit is contained in:
Eric Eastwood 2024-05-16 17:04:26 -05:00
parent 7331401e89
commit b23abca9e7
4 changed files with 699 additions and 600 deletions

View file

@ -1,7 +1,281 @@
from tests.rest.client.test_sendtodevice_base import SendToDeviceTestCaseBase
from tests.unittest import HomeserverTestCase
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase, override_config
class SendToDeviceTestCase(SendToDeviceTestCaseBase, HomeserverTestCase):
# See SendToDeviceTestCaseBase for tests
class NotTested:
"""
We nest the base test class to avoid the tests being run twice by the test runner
when we share/import these tests in other files. Without this, Twisted trial throws
a `KeyError` in the reporter when using multiple jobs (`poetry run trial --jobs=6`).
"""
class SendToDeviceTestCaseBase(HomeserverTestCase):
"""
Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`.
In order to run the tests, inherit from this base-class with `HomeserverTestCase`, e.g.
`class SendToDeviceTestCase(SendToDeviceTestCase, HomeserverTestCase)`
"""
servlets = [
admin.register_servlets,
login.register_servlets,
sendtodevice.register_servlets,
sync.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_endpoint = "/sync"
def test_user_to_user(self) -> None:
"""A to-device message from one user to another should get delivered"""
user1 = self.register_user("u1", "pass")
user1_tok = self.login("u1", "pass", "d1")
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
# send the message
test_msg = {"foo": "bar"}
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.test/1234",
content={"messages": {user2: {"d2": test_msg}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
# check it appears
channel = self.make_request(
"GET", self.sync_endpoint, access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
expected_result = {
"events": [
{
"sender": user1,
"type": "m.test",
"content": test_msg,
}
]
}
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should re-appear if we do another sync because the to-device message is not
# deleted until we acknowledge it by sending a `?since=...` parameter in the
# next sync request corresponding to the `next_batch` value from the response.
channel = self.make_request(
"GET", self.sync_endpoint, access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should *not* appear if we do an incremental sync
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(
channel.json_body.get("to_device", {}).get("events", []), []
)
@override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}})
def test_local_room_key_request(self) -> None:
"""m.room_key_request has special-casing; test from local user"""
user1 = self.register_user("u1", "pass")
user1_tok = self.login("u1", "pass", "d1")
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
# send three messages
for i in range(3):
chan = self.make_request(
"PUT",
f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}",
content={"messages": {user2: {"d2": {"idx": i}}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
# now sync: we should get two of the three (because burst_count=2)
channel = self.make_request(
"GET", self.sync_endpoint, access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
for i in range(2):
self.assertEqual(
msgs[i],
{
"sender": user1,
"type": "m.room_key_request",
"content": {"idx": i},
},
)
sync_token = channel.json_body["next_batch"]
# ... time passes
self.reactor.advance(1)
# and we can send more messages
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.room_key_request/3",
content={"messages": {user2: {"d2": {"idx": 3}}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
# ... which should arrive
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 1)
self.assertEqual(
msgs[0],
{"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}},
)
@override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}})
def test_remote_room_key_request(self) -> None:
"""m.room_key_request has special-casing; test from remote user"""
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
federation_registry = self.hs.get_federation_registry()
# send three messages
for i in range(3):
self.get_success(
federation_registry.on_edu(
EduTypes.DIRECT_TO_DEVICE,
"remote_server",
{
"sender": "@user:remote_server",
"type": "m.room_key_request",
"messages": {user2: {"d2": {"idx": i}}},
"message_id": f"{i}",
},
)
)
# now sync: we should get two of the three
channel = self.make_request(
"GET", self.sync_endpoint, access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
for i in range(2):
self.assertEqual(
msgs[i],
{
"sender": "@user:remote_server",
"type": "m.room_key_request",
"content": {"idx": i},
},
)
sync_token = channel.json_body["next_batch"]
# ... time passes
self.reactor.advance(1)
# and we can send more messages
self.get_success(
federation_registry.on_edu(
EduTypes.DIRECT_TO_DEVICE,
"remote_server",
{
"sender": "@user:remote_server",
"type": "m.room_key_request",
"messages": {user2: {"d2": {"idx": 3}}},
"message_id": "3",
},
)
)
# ... which should arrive
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 1)
self.assertEqual(
msgs[0],
{
"sender": "@user:remote_server",
"type": "m.room_key_request",
"content": {"idx": 3},
},
)
def test_limited_sync(self) -> None:
"""If a limited sync for to-devices happens the next /sync should respond immediately."""
self.register_user("u1", "pass")
user1_tok = self.login("u1", "pass", "d1")
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
# Do an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
sync_token = channel.json_body["next_batch"]
# Send 150 to-device messages. We limit to 100 in `/sync`
for i in range(150):
test_msg = {"foo": "bar"}
chan = self.make_request(
"PUT",
f"/_matrix/client/r0/sendToDevice/m.test/1234-{i}",
content={"messages": {user2: {"d2": test_msg}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
self.assertEqual(len(messages), 100)
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
self.assertEqual(len(messages), 50)
class SendToDeviceTestCase(NotTested.SendToDeviceTestCaseBase):
# See SendToDeviceTestCaseBase above
pass

View file

@ -1,268 +0,0 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase, override_config
class SendToDeviceTestCaseBase(HomeserverTestCase):
"""
Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`.
In order to run the tests, inherit from this base-class with `HomeserverTestCase`, e.g.
`class SendToDeviceTestCase(SendToDeviceTestCase, HomeserverTestCase)`
"""
servlets = [
admin.register_servlets,
login.register_servlets,
sendtodevice.register_servlets,
sync.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_endpoint = "/sync"
def test_user_to_user(self) -> None:
"""A to-device message from one user to another should get delivered"""
user1 = self.register_user("u1", "pass")
user1_tok = self.login("u1", "pass", "d1")
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
# send the message
test_msg = {"foo": "bar"}
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.test/1234",
content={"messages": {user2: {"d2": test_msg}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
# check it appears
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
expected_result = {
"events": [
{
"sender": user1,
"type": "m.test",
"content": test_msg,
}
]
}
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should re-appear if we do another sync because the to-device message is not
# deleted until we acknowledge it by sending a `?since=...` parameter in the
# next sync request corresponding to the `next_batch` value from the response.
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should *not* appear if we do an incremental sync
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
@override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}})
def test_local_room_key_request(self) -> None:
"""m.room_key_request has special-casing; test from local user"""
user1 = self.register_user("u1", "pass")
user1_tok = self.login("u1", "pass", "d1")
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
# send three messages
for i in range(3):
chan = self.make_request(
"PUT",
f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}",
content={"messages": {user2: {"d2": {"idx": i}}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
# now sync: we should get two of the three (because burst_count=2)
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
for i in range(2):
self.assertEqual(
msgs[i],
{"sender": user1, "type": "m.room_key_request", "content": {"idx": i}},
)
sync_token = channel.json_body["next_batch"]
# ... time passes
self.reactor.advance(1)
# and we can send more messages
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.room_key_request/3",
content={"messages": {user2: {"d2": {"idx": 3}}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
# ... which should arrive
channel = self.make_request(
"GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 1)
self.assertEqual(
msgs[0],
{"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}},
)
@override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}})
def test_remote_room_key_request(self) -> None:
"""m.room_key_request has special-casing; test from remote user"""
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
federation_registry = self.hs.get_federation_registry()
# send three messages
for i in range(3):
self.get_success(
federation_registry.on_edu(
EduTypes.DIRECT_TO_DEVICE,
"remote_server",
{
"sender": "@user:remote_server",
"type": "m.room_key_request",
"messages": {user2: {"d2": {"idx": i}}},
"message_id": f"{i}",
},
)
)
# now sync: we should get two of the three
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
for i in range(2):
self.assertEqual(
msgs[i],
{
"sender": "@user:remote_server",
"type": "m.room_key_request",
"content": {"idx": i},
},
)
sync_token = channel.json_body["next_batch"]
# ... time passes
self.reactor.advance(1)
# and we can send more messages
self.get_success(
federation_registry.on_edu(
EduTypes.DIRECT_TO_DEVICE,
"remote_server",
{
"sender": "@user:remote_server",
"type": "m.room_key_request",
"messages": {user2: {"d2": {"idx": 3}}},
"message_id": "3",
},
)
)
# ... which should arrive
channel = self.make_request(
"GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 1)
self.assertEqual(
msgs[0],
{
"sender": "@user:remote_server",
"type": "m.room_key_request",
"content": {"idx": 3},
},
)
def test_limited_sync(self) -> None:
"""If a limited sync for to-devices happens the next /sync should respond immediately."""
self.register_user("u1", "pass")
user1_tok = self.login("u1", "pass", "d1")
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
# Do an initial sync
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
sync_token = channel.json_body["next_batch"]
# Send 150 to-device messages. We limit to 100 in `/sync`
for i in range(150):
test_msg = {"foo": "bar"}
chan = self.make_request(
"PUT",
f"/_matrix/client/r0/sendToDevice/m.test/1234-{i}",
content={"messages": {user2: {"d2": test_msg}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
self.assertEqual(len(messages), 100)
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
self.assertEqual(len(messages), 50)

View file

@ -4,18 +4,18 @@ from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
# TODO: Uncomment this line when we have a pattern to share tests across files, see
# https://github.com/element-hq/synapse/pull/17167#discussion_r1594517041
#
# from tests.rest.client.test_sync import DeviceListSyncTestCase
# from tests.rest.client.test_sync import DeviceOneTimeKeysSyncTestCase
# from tests.rest.client.test_sync import DeviceUnusedFallbackKeySyncTestCase
from tests.rest.client.test_sendtodevice_base import SendToDeviceTestCaseBase
from tests.unittest import HomeserverTestCase
from tests.rest.client.test_sendtodevice import NotTested as SendToDeviceNotTested
from tests.rest.client.test_sync import NotTested as SyncNotTested
# Test To-Device messages working correctly with the `/sync/e2ee` endpoint (`to_device`)
class SlidingSyncE2eeSendToDeviceTestCase(SendToDeviceTestCaseBase, HomeserverTestCase):
class SlidingSyncE2eeSendToDeviceTestCase(
SendToDeviceNotTested.SendToDeviceTestCaseBase
):
"""
Test To-Device messages working correctly with the `/sync/e2ee` endpoint
(`to_device`)
"""
def default_config(self) -> JsonDict:
config = super().default_config()
# Enable sliding sync
@ -23,7 +23,71 @@ class SlidingSyncE2eeSendToDeviceTestCase(SendToDeviceTestCaseBase, HomeserverTe
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
# Use the Sliding Sync `/sync/e2ee` endpoint
self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee"
# See SendToDeviceTestCaseBase for tests
class SlidingSyncE2eeDeviceListSyncTestCase(SyncNotTested.DeviceListSyncTestCaseBase):
"""
Test device lists working correctly with the `/sync/e2ee` endpoint (`device_lists`)
"""
def default_config(self) -> JsonDict:
config = super().default_config()
# Enable sliding sync
config["experimental_features"] = {"msc3575_enabled": True}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
# Use the Sliding Sync `/sync/e2ee` endpoint
self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee"
# See DeviceListSyncTestCaseBase for tests
class SlidingSyncE2eeDeviceOneTimeKeysSyncTestCase(
SyncNotTested.DeviceOneTimeKeysSyncTestCaseBase
):
"""
Test device one time keys working correctly with the `/sync/e2ee` endpoint
(`device_one_time_keys_count`)
"""
def default_config(self) -> JsonDict:
config = super().default_config()
# Enable sliding sync
config["experimental_features"] = {"msc3575_enabled": True}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
# Use the Sliding Sync `/sync/e2ee` endpoint
self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee"
# See DeviceOneTimeKeysSyncTestCaseBase for tests
class SlidingSyncE2eeDeviceUnusedFallbackKeySyncTestCase(
SyncNotTested.DeviceUnusedFallbackKeySyncTestCaseBase
):
"""
Test device unused fallback key types working correctly with the `/sync/e2ee`
endpoint (`device_unused_fallback_key_types`)
"""
def default_config(self) -> JsonDict:
config = super().default_config()
# Enable sliding sync
config["experimental_features"] = {"msc3575_enabled": True}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
# Use the Sliding Sync `/sync/e2ee` endpoint
self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee"
# See DeviceUnusedFallbackKeySyncTestCaseBase for tests

View file

@ -688,367 +688,396 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.json_body)
class DeviceListSyncTestCase(unittest.HomeserverTestCase):
"""Tests regarding device list (`device_lists`) changes."""
class NotTested:
"""
We nest the base test class to avoid the tests being run twice by the test runner
when we share/import these tests in other files. Without this, Twisted trial throws
a `KeyError` in the reporter when using multiple jobs (`poetry run trial --jobs=6`).
"""
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
class DeviceListSyncTestCaseBase(unittest.HomeserverTestCase):
"""Tests regarding device list (`device_lists`) changes."""
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_endpoint = "/sync"
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def test_receiving_local_device_list_changes(self) -> None:
"""Tests that a local users that share a room receive each other's device list
changes.
"""
# Register two users
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_endpoint = "/sync"
bob_user_id = self.register_user("bob", "ponyponypony")
bob_access_token = self.login(bob_user_id, "ponyponypony")
def test_receiving_local_device_list_changes(self) -> None:
"""Tests that a local users that share a room receive each other's device list
changes.
"""
# Register two users
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Create a room for them to coexist peacefully in
new_room_id = self.helper.create_room_as(
alice_user_id, is_public=True, tok=alice_access_token
)
self.assertIsNotNone(new_room_id)
bob_user_id = self.register_user("bob", "ponyponypony")
bob_access_token = self.login(bob_user_id, "ponyponypony")
# Have Bob join the room
self.helper.invite(
new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
)
self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
# Create a room for them to coexist peacefully in
new_room_id = self.helper.create_room_as(
alice_user_id, is_public=True, tok=alice_access_token
)
self.assertIsNotNone(new_room_id)
# Now have Bob initiate an initial sync (in order to get a since token)
channel = self.make_request(
"GET",
self.sync_endpoint,
access_token=bob_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch_token = channel.json_body["next_batch"]
# Have Bob join the room
self.helper.invite(
new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
)
self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
# ...and then an incremental sync. This should block until the sync stream is woken up,
# which we hope will happen as a result of Alice updating their device list.
bob_sync_channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000",
access_token=bob_access_token,
# Start the request, then continue on.
await_result=False,
)
# Now have Bob initiate an initial sync (in order to get a since token)
channel = self.make_request(
"GET",
self.sync_endpoint,
access_token=bob_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch_token = channel.json_body["next_batch"]
# Have alice update their device list
channel = self.make_request(
"PUT",
f"/devices/{test_device_id}",
{
"display_name": "New Device Name",
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
# ...and then an incremental sync. This should block until the sync stream is woken up,
# which we hope will happen as a result of Alice updating their device list.
bob_sync_channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000",
access_token=bob_access_token,
# Start the request, then continue on.
await_result=False,
)
# Check that bob's incremental sync contains the updated device list.
# If not, the client would only receive the device list update on the
# *next* sync.
bob_sync_channel.await_result()
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
# Have alice update their device list
channel = self.make_request(
"PUT",
f"/devices/{test_device_id}",
{
"display_name": "New Device Name",
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
"changed", []
)
self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
# Check that bob's incremental sync contains the updated device list.
# If not, the client would only receive the device list update on the
# *next* sync.
bob_sync_channel.await_result()
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
def test_not_receiving_local_device_list_changes(self) -> None:
"""Tests a local users DO NOT receive device updates from each other if they do not
share a room.
"""
# Register two users
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
changed_device_lists = bob_sync_channel.json_body.get(
"device_lists", {}
).get("changed", [])
self.assertIn(
alice_user_id, changed_device_lists, bob_sync_channel.json_body
)
bob_user_id = self.register_user("bob", "ponyponypony")
bob_access_token = self.login(bob_user_id, "ponyponypony")
def test_not_receiving_local_device_list_changes(self) -> None:
"""Tests a local users DO NOT receive device updates from each other if they do not
share a room.
"""
# Register two users
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# These users do not share a room. They are lonely.
bob_user_id = self.register_user("bob", "ponyponypony")
bob_access_token = self.login(bob_user_id, "ponyponypony")
# Have Bob initiate an initial sync (in order to get a since token)
channel = self.make_request(
"GET",
self.sync_endpoint,
access_token=bob_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch_token = channel.json_body["next_batch"]
# These users do not share a room. They are lonely.
# ...and then an incremental sync. This should block until the sync stream is woken up,
# which we hope will happen as a result of Alice updating their device list.
bob_sync_channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000",
access_token=bob_access_token,
# Start the request, then continue on.
await_result=False,
)
# Have Bob initiate an initial sync (in order to get a since token)
channel = self.make_request(
"GET",
self.sync_endpoint,
access_token=bob_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch_token = channel.json_body["next_batch"]
# Have alice update their device list
channel = self.make_request(
"PUT",
f"/devices/{test_device_id}",
{
"display_name": "New Device Name",
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
# ...and then an incremental sync. This should block until the sync stream is woken up,
# which we hope will happen as a result of Alice updating their device list.
bob_sync_channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000",
access_token=bob_access_token,
# Start the request, then continue on.
await_result=False,
)
# Check that bob's incremental sync does not contain the updated device list.
bob_sync_channel.await_result()
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
# Have alice update their device list
channel = self.make_request(
"PUT",
f"/devices/{test_device_id}",
{
"display_name": "New Device Name",
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
"changed", []
)
self.assertNotIn(
alice_user_id, changed_device_lists, bob_sync_channel.json_body
)
# Check that bob's incremental sync does not contain the updated device list.
bob_sync_channel.await_result()
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
"""Tests that a user with no rooms still receives their own device list updates"""
test_device_id = "TESTDEVICE"
changed_device_lists = bob_sync_channel.json_body.get(
"device_lists", {}
).get("changed", [])
self.assertNotIn(
alice_user_id, changed_device_lists, bob_sync_channel.json_body
)
# Register a user and login, creating a device
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
"""Tests that a user with no rooms still receives their own device list updates"""
test_device_id = "TESTDEVICE"
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch = channel.json_body["next_batch"]
# Register a user and login, creating a device
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Now, make an incremental sync request.
# It won't return until something has happened
incremental_sync_channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={next_batch}&timeout=30000",
access_token=alice_access_token,
await_result=False,
)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch = channel.json_body["next_batch"]
# Change our device's display name
channel = self.make_request(
"PUT",
f"devices/{test_device_id}",
{
"display_name": "freeze ray",
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Now, make an incremental sync request.
# It won't return until something has happened
incremental_sync_channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={next_batch}&timeout=30000",
access_token=alice_access_token,
await_result=False,
)
# The sync should now have returned
incremental_sync_channel.await_result(timeout_ms=20000)
self.assertEqual(incremental_sync_channel.code, 200, channel.json_body)
# Change our device's display name
channel = self.make_request(
"PUT",
f"devices/{test_device_id}",
{
"display_name": "freeze ray",
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
# We should have received notification that the (user's) device has changed
device_list_changes = incremental_sync_channel.json_body.get(
"device_lists", {}
).get("changed", [])
# The sync should now have returned
incremental_sync_channel.await_result(timeout_ms=20000)
self.assertEqual(incremental_sync_channel.code, 200, channel.json_body)
self.assertIn(
alice_user_id, device_list_changes, incremental_sync_channel.json_body
)
# We should have received notification that the (user's) device has changed
device_list_changes = incremental_sync_channel.json_body.get(
"device_lists", {}
).get("changed", [])
self.assertIn(
alice_user_id, device_list_changes, incremental_sync_channel.json_body
)
class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
"""Tests regarding device one time keys (`device_one_time_keys_count`) changes."""
class DeviceOneTimeKeysSyncTestCaseBase(unittest.HomeserverTestCase):
"""Tests regarding device one time keys (`device_one_time_keys_count`) changes."""
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_endpoint = "/sync"
self.e2e_keys_handler = hs.get_e2e_keys_handler()
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_endpoint = "/sync"
self.e2e_keys_handler = hs.get_e2e_keys_handler()
def test_no_device_one_time_keys(self) -> None:
"""
Tests when no one time keys set, it still has the default `signed_curve25519` in
`device_one_time_keys_count`
"""
test_device_id = "TESTDEVICE"
def test_no_device_one_time_keys(self) -> None:
"""
Tests when no one time keys set, it still has the default `signed_curve25519` in
`device_one_time_keys_count`
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
self.assertDictEqual(
channel.json_body["device_one_time_keys_count"],
# Check for those one time key counts
self.assertDictEqual(
channel.json_body["device_one_time_keys_count"],
# Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
{"signed_curve25519": 0},
channel.json_body["device_one_time_keys_count"],
)
def test_returns_device_one_time_keys(self) -> None:
"""
Tests that one time keys for the device/user are counted correctly in the `/sync`
response
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Upload one time keys for the user/device
keys: JsonDict = {
"alg1:k1": "key1",
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"alg2:k3": {"key": "key3"},
}
res = self.get_success(
self.e2e_keys_handler.upload_keys_for_user(
alice_user_id, test_device_id, {"one_time_keys": keys}
)
)
# Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
{"signed_curve25519": 0},
channel.json_body["device_one_time_keys_count"],
)
def test_returns_device_one_time_keys(self) -> None:
"""
Tests that one time keys for the device/user are counted correctly in the `/sync`
response
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Upload one time keys for the user/device
keys: JsonDict = {
"alg1:k1": "key1",
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"alg2:k3": {"key": "key3"},
}
res = self.get_success(
self.e2e_keys_handler.upload_keys_for_user(
alice_user_id, test_device_id, {"one_time_keys": keys}
self.assertDictEqual(
res,
{"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}},
)
)
# Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
self.assertDictEqual(
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
self.assertDictEqual(
channel.json_body["device_one_time_keys_count"],
{"alg1": 1, "alg2": 2, "signed_curve25519": 0},
channel.json_body["device_one_time_keys_count"],
)
class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
"""Tests regarding device one time keys (`device_unused_fallback_key_types`) changes."""
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_endpoint = "/sync"
self.store = self.hs.get_datastores().main
self.e2e_keys_handler = hs.get_e2e_keys_handler()
def test_no_device_unused_fallback_key(self) -> None:
"""
Test when no unused fallback key is set, it just returns an empty list. The MSC
says "The device_unused_fallback_key_types parameter must be present if the
server supports fallback keys.",
https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
self.assertListEqual(
channel.json_body["device_unused_fallback_key_types"],
[],
channel.json_body["device_unused_fallback_key_types"],
)
def test_returns_device_one_time_keys(self) -> None:
"""
Tests that device unused fallback key type is returned correctly in the `/sync`
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# We shouldn't have any unused fallback keys yet
res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
)
self.assertEqual(res, [])
# Upload a fallback key for the user/device
fallback_key = {"alg1:k1": "fallback_key1"}
self.get_success(
self.e2e_keys_handler.upload_keys_for_user(
alice_user_id,
test_device_id,
{"fallback_keys": fallback_key},
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
)
# We should now have an unused alg1 key
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
)
self.assertEqual(fallback_res, ["alg1"], fallback_res)
self.assertEqual(channel.code, 200, channel.json_body)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
self.assertDictEqual(
channel.json_body["device_one_time_keys_count"],
{"alg1": 1, "alg2": 2, "signed_curve25519": 0},
channel.json_body["device_one_time_keys_count"],
)
# Check for the unused fallback key types
self.assertListEqual(
channel.json_body["device_unused_fallback_key_types"],
["alg1"],
channel.json_body["device_unused_fallback_key_types"],
)
class DeviceUnusedFallbackKeySyncTestCaseBase(unittest.HomeserverTestCase):
"""Tests regarding device one time keys (`device_unused_fallback_key_types`) changes."""
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_endpoint = "/sync"
self.store = self.hs.get_datastores().main
self.e2e_keys_handler = hs.get_e2e_keys_handler()
def test_no_device_unused_fallback_key(self) -> None:
"""
Test when no unused fallback key is set, it just returns an empty list. The MSC
says "The device_unused_fallback_key_types parameter must be present if the
server supports fallback keys.",
https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
self.assertListEqual(
channel.json_body["device_unused_fallback_key_types"],
[],
channel.json_body["device_unused_fallback_key_types"],
)
def test_returns_device_one_time_keys(self) -> None:
"""
Tests that device unused fallback key type is returned correctly in the `/sync`
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# We shouldn't have any unused fallback keys yet
res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(
alice_user_id, test_device_id
)
)
self.assertEqual(res, [])
# Upload a fallback key for the user/device
fallback_key = {"alg1:k1": "fallback_key1"}
self.get_success(
self.e2e_keys_handler.upload_keys_for_user(
alice_user_id,
test_device_id,
{"fallback_keys": fallback_key},
)
)
# We should now have an unused alg1 key
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(
alice_user_id, test_device_id
)
)
self.assertEqual(fallback_res, ["alg1"], fallback_res)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for the unused fallback key types
self.assertListEqual(
channel.json_body["device_unused_fallback_key_types"],
["alg1"],
channel.json_body["device_unused_fallback_key_types"],
)
class DeviceListSyncTestCase(NotTested.DeviceListSyncTestCaseBase):
# See DeviceListSyncTestCaseBase above
pass
class DeviceOneTimeKeysSyncTestCase(NotTested.DeviceOneTimeKeysSyncTestCaseBase):
# See DeviceOneTimeKeysSyncTestCaseBase above
pass
class DeviceUnusedFallbackKeySyncTestCase(
NotTested.DeviceUnusedFallbackKeySyncTestCaseBase
):
# See DeviceUnusedFallbackKeySyncTestCaseBase above
pass
class ExcludeRoomTestCase(unittest.HomeserverTestCase):