Fix test inheritance

See https://github.com/element-hq/synapse/pull/17167#discussion_r1594517041
This commit is contained in:
Eric Eastwood 2024-05-16 17:04:26 -05:00
parent 7331401e89
commit b23abca9e7
4 changed files with 699 additions and 600 deletions

View file

@ -1,7 +1,281 @@
from tests.rest.client.test_sendtodevice_base import SendToDeviceTestCaseBase from twisted.test.proto_helpers import MemoryReactor
from tests.unittest import HomeserverTestCase
from synapse.api.constants import EduTypes
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase, override_config
class SendToDeviceTestCase(SendToDeviceTestCaseBase, HomeserverTestCase): class NotTested:
# See SendToDeviceTestCaseBase for tests """
We nest the base test class to avoid the tests being run twice by the test runner
when we share/import these tests in other files. Without this, Twisted trial throws
a `KeyError` in the reporter when using multiple jobs (`poetry run trial --jobs=6`).
"""
class SendToDeviceTestCaseBase(HomeserverTestCase):
"""
Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`.
In order to run the tests, inherit from this base-class with `HomeserverTestCase`, e.g.
`class SendToDeviceTestCase(SendToDeviceTestCase, HomeserverTestCase)`
"""
servlets = [
admin.register_servlets,
login.register_servlets,
sendtodevice.register_servlets,
sync.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_endpoint = "/sync"
def test_user_to_user(self) -> None:
"""A to-device message from one user to another should get delivered"""
user1 = self.register_user("u1", "pass")
user1_tok = self.login("u1", "pass", "d1")
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
# send the message
test_msg = {"foo": "bar"}
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.test/1234",
content={"messages": {user2: {"d2": test_msg}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
# check it appears
channel = self.make_request(
"GET", self.sync_endpoint, access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
expected_result = {
"events": [
{
"sender": user1,
"type": "m.test",
"content": test_msg,
}
]
}
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should re-appear if we do another sync because the to-device message is not
# deleted until we acknowledge it by sending a `?since=...` parameter in the
# next sync request corresponding to the `next_batch` value from the response.
channel = self.make_request(
"GET", self.sync_endpoint, access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should *not* appear if we do an incremental sync
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(
channel.json_body.get("to_device", {}).get("events", []), []
)
@override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}})
def test_local_room_key_request(self) -> None:
"""m.room_key_request has special-casing; test from local user"""
user1 = self.register_user("u1", "pass")
user1_tok = self.login("u1", "pass", "d1")
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
# send three messages
for i in range(3):
chan = self.make_request(
"PUT",
f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}",
content={"messages": {user2: {"d2": {"idx": i}}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
# now sync: we should get two of the three (because burst_count=2)
channel = self.make_request(
"GET", self.sync_endpoint, access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
for i in range(2):
self.assertEqual(
msgs[i],
{
"sender": user1,
"type": "m.room_key_request",
"content": {"idx": i},
},
)
sync_token = channel.json_body["next_batch"]
# ... time passes
self.reactor.advance(1)
# and we can send more messages
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.room_key_request/3",
content={"messages": {user2: {"d2": {"idx": 3}}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
# ... which should arrive
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 1)
self.assertEqual(
msgs[0],
{"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}},
)
@override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}})
def test_remote_room_key_request(self) -> None:
"""m.room_key_request has special-casing; test from remote user"""
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
federation_registry = self.hs.get_federation_registry()
# send three messages
for i in range(3):
self.get_success(
federation_registry.on_edu(
EduTypes.DIRECT_TO_DEVICE,
"remote_server",
{
"sender": "@user:remote_server",
"type": "m.room_key_request",
"messages": {user2: {"d2": {"idx": i}}},
"message_id": f"{i}",
},
)
)
# now sync: we should get two of the three
channel = self.make_request(
"GET", self.sync_endpoint, access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
for i in range(2):
self.assertEqual(
msgs[i],
{
"sender": "@user:remote_server",
"type": "m.room_key_request",
"content": {"idx": i},
},
)
sync_token = channel.json_body["next_batch"]
# ... time passes
self.reactor.advance(1)
# and we can send more messages
self.get_success(
federation_registry.on_edu(
EduTypes.DIRECT_TO_DEVICE,
"remote_server",
{
"sender": "@user:remote_server",
"type": "m.room_key_request",
"messages": {user2: {"d2": {"idx": 3}}},
"message_id": "3",
},
)
)
# ... which should arrive
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 1)
self.assertEqual(
msgs[0],
{
"sender": "@user:remote_server",
"type": "m.room_key_request",
"content": {"idx": 3},
},
)
def test_limited_sync(self) -> None:
"""If a limited sync for to-devices happens the next /sync should respond immediately."""
self.register_user("u1", "pass")
user1_tok = self.login("u1", "pass", "d1")
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
# Do an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
sync_token = channel.json_body["next_batch"]
# Send 150 to-device messages. We limit to 100 in `/sync`
for i in range(150):
test_msg = {"foo": "bar"}
chan = self.make_request(
"PUT",
f"/_matrix/client/r0/sendToDevice/m.test/1234-{i}",
content={"messages": {user2: {"d2": test_msg}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
self.assertEqual(len(messages), 100)
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
self.assertEqual(len(messages), 50)
class SendToDeviceTestCase(NotTested.SendToDeviceTestCaseBase):
# See SendToDeviceTestCaseBase above
pass pass

View file

@ -1,268 +0,0 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase, override_config
class SendToDeviceTestCaseBase(HomeserverTestCase):
"""
Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`.
In order to run the tests, inherit from this base-class with `HomeserverTestCase`, e.g.
`class SendToDeviceTestCase(SendToDeviceTestCase, HomeserverTestCase)`
"""
servlets = [
admin.register_servlets,
login.register_servlets,
sendtodevice.register_servlets,
sync.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_endpoint = "/sync"
def test_user_to_user(self) -> None:
"""A to-device message from one user to another should get delivered"""
user1 = self.register_user("u1", "pass")
user1_tok = self.login("u1", "pass", "d1")
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
# send the message
test_msg = {"foo": "bar"}
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.test/1234",
content={"messages": {user2: {"d2": test_msg}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
# check it appears
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
expected_result = {
"events": [
{
"sender": user1,
"type": "m.test",
"content": test_msg,
}
]
}
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should re-appear if we do another sync because the to-device message is not
# deleted until we acknowledge it by sending a `?since=...` parameter in the
# next sync request corresponding to the `next_batch` value from the response.
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should *not* appear if we do an incremental sync
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
@override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}})
def test_local_room_key_request(self) -> None:
"""m.room_key_request has special-casing; test from local user"""
user1 = self.register_user("u1", "pass")
user1_tok = self.login("u1", "pass", "d1")
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
# send three messages
for i in range(3):
chan = self.make_request(
"PUT",
f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}",
content={"messages": {user2: {"d2": {"idx": i}}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
# now sync: we should get two of the three (because burst_count=2)
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
for i in range(2):
self.assertEqual(
msgs[i],
{"sender": user1, "type": "m.room_key_request", "content": {"idx": i}},
)
sync_token = channel.json_body["next_batch"]
# ... time passes
self.reactor.advance(1)
# and we can send more messages
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.room_key_request/3",
content={"messages": {user2: {"d2": {"idx": 3}}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
# ... which should arrive
channel = self.make_request(
"GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 1)
self.assertEqual(
msgs[0],
{"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}},
)
@override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}})
def test_remote_room_key_request(self) -> None:
"""m.room_key_request has special-casing; test from remote user"""
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
federation_registry = self.hs.get_federation_registry()
# send three messages
for i in range(3):
self.get_success(
federation_registry.on_edu(
EduTypes.DIRECT_TO_DEVICE,
"remote_server",
{
"sender": "@user:remote_server",
"type": "m.room_key_request",
"messages": {user2: {"d2": {"idx": i}}},
"message_id": f"{i}",
},
)
)
# now sync: we should get two of the three
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
for i in range(2):
self.assertEqual(
msgs[i],
{
"sender": "@user:remote_server",
"type": "m.room_key_request",
"content": {"idx": i},
},
)
sync_token = channel.json_body["next_batch"]
# ... time passes
self.reactor.advance(1)
# and we can send more messages
self.get_success(
federation_registry.on_edu(
EduTypes.DIRECT_TO_DEVICE,
"remote_server",
{
"sender": "@user:remote_server",
"type": "m.room_key_request",
"messages": {user2: {"d2": {"idx": 3}}},
"message_id": "3",
},
)
)
# ... which should arrive
channel = self.make_request(
"GET", f"{self.sync_endpoint}?since={sync_token}", access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 1)
self.assertEqual(
msgs[0],
{
"sender": "@user:remote_server",
"type": "m.room_key_request",
"content": {"idx": 3},
},
)
def test_limited_sync(self) -> None:
"""If a limited sync for to-devices happens the next /sync should respond immediately."""
self.register_user("u1", "pass")
user1_tok = self.login("u1", "pass", "d1")
user2 = self.register_user("u2", "pass")
user2_tok = self.login("u2", "pass", "d2")
# Do an initial sync
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
sync_token = channel.json_body["next_batch"]
# Send 150 to-device messages. We limit to 100 in `/sync`
for i in range(150):
test_msg = {"foo": "bar"}
chan = self.make_request(
"PUT",
f"/_matrix/client/r0/sendToDevice/m.test/1234-{i}",
content={"messages": {user2: {"d2": test_msg}}},
access_token=user1_tok,
)
self.assertEqual(chan.code, 200, chan.result)
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
self.assertEqual(len(messages), 100)
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
self.assertEqual(len(messages), 50)

View file

@ -4,18 +4,18 @@ from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
# TODO: Uncomment this line when we have a pattern to share tests across files, see from tests.rest.client.test_sendtodevice import NotTested as SendToDeviceNotTested
# https://github.com/element-hq/synapse/pull/17167#discussion_r1594517041 from tests.rest.client.test_sync import NotTested as SyncNotTested
#
# from tests.rest.client.test_sync import DeviceListSyncTestCase
# from tests.rest.client.test_sync import DeviceOneTimeKeysSyncTestCase
# from tests.rest.client.test_sync import DeviceUnusedFallbackKeySyncTestCase
from tests.rest.client.test_sendtodevice_base import SendToDeviceTestCaseBase
from tests.unittest import HomeserverTestCase
# Test To-Device messages working correctly with the `/sync/e2ee` endpoint (`to_device`) class SlidingSyncE2eeSendToDeviceTestCase(
class SlidingSyncE2eeSendToDeviceTestCase(SendToDeviceTestCaseBase, HomeserverTestCase): SendToDeviceNotTested.SendToDeviceTestCaseBase
):
"""
Test To-Device messages working correctly with the `/sync/e2ee` endpoint
(`to_device`)
"""
def default_config(self) -> JsonDict: def default_config(self) -> JsonDict:
config = super().default_config() config = super().default_config()
# Enable sliding sync # Enable sliding sync
@ -23,7 +23,71 @@ class SlidingSyncE2eeSendToDeviceTestCase(SendToDeviceTestCaseBase, HomeserverTe
return config return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
# Use the Sliding Sync `/sync/e2ee` endpoint # Use the Sliding Sync `/sync/e2ee` endpoint
self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee" self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee"
# See SendToDeviceTestCaseBase for tests # See SendToDeviceTestCaseBase for tests
class SlidingSyncE2eeDeviceListSyncTestCase(SyncNotTested.DeviceListSyncTestCaseBase):
"""
Test device lists working correctly with the `/sync/e2ee` endpoint (`device_lists`)
"""
def default_config(self) -> JsonDict:
config = super().default_config()
# Enable sliding sync
config["experimental_features"] = {"msc3575_enabled": True}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
# Use the Sliding Sync `/sync/e2ee` endpoint
self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee"
# See DeviceListSyncTestCaseBase for tests
class SlidingSyncE2eeDeviceOneTimeKeysSyncTestCase(
SyncNotTested.DeviceOneTimeKeysSyncTestCaseBase
):
"""
Test device one time keys working correctly with the `/sync/e2ee` endpoint
(`device_one_time_keys_count`)
"""
def default_config(self) -> JsonDict:
config = super().default_config()
# Enable sliding sync
config["experimental_features"] = {"msc3575_enabled": True}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
# Use the Sliding Sync `/sync/e2ee` endpoint
self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee"
# See DeviceOneTimeKeysSyncTestCaseBase for tests
class SlidingSyncE2eeDeviceUnusedFallbackKeySyncTestCase(
SyncNotTested.DeviceUnusedFallbackKeySyncTestCaseBase
):
"""
Test device unused fallback key types working correctly with the `/sync/e2ee`
endpoint (`device_unused_fallback_key_types`)
"""
def default_config(self) -> JsonDict:
config = super().default_config()
# Enable sliding sync
config["experimental_features"] = {"msc3575_enabled": True}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
# Use the Sliding Sync `/sync/e2ee` endpoint
self.sync_endpoint = "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee"
# See DeviceUnusedFallbackKeySyncTestCaseBase for tests

View file

@ -688,367 +688,396 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
class DeviceListSyncTestCase(unittest.HomeserverTestCase): class NotTested:
"""Tests regarding device list (`device_lists`) changes.""" """
We nest the base test class to avoid the tests being run twice by the test runner
when we share/import these tests in other files. Without this, Twisted trial throws
a `KeyError` in the reporter when using multiple jobs (`poetry run trial --jobs=6`).
"""
servlets = [ class DeviceListSyncTestCaseBase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets, """Tests regarding device list (`device_lists`) changes."""
login.register_servlets,
room.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: servlets = [
self.sync_endpoint = "/sync" synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def test_receiving_local_device_list_changes(self) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
"""Tests that a local users that share a room receive each other's device list self.sync_endpoint = "/sync"
changes.
"""
# Register two users
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
bob_user_id = self.register_user("bob", "ponyponypony") def test_receiving_local_device_list_changes(self) -> None:
bob_access_token = self.login(bob_user_id, "ponyponypony") """Tests that a local users that share a room receive each other's device list
changes.
"""
# Register two users
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Create a room for them to coexist peacefully in bob_user_id = self.register_user("bob", "ponyponypony")
new_room_id = self.helper.create_room_as( bob_access_token = self.login(bob_user_id, "ponyponypony")
alice_user_id, is_public=True, tok=alice_access_token
)
self.assertIsNotNone(new_room_id)
# Have Bob join the room # Create a room for them to coexist peacefully in
self.helper.invite( new_room_id = self.helper.create_room_as(
new_room_id, alice_user_id, bob_user_id, tok=alice_access_token alice_user_id, is_public=True, tok=alice_access_token
) )
self.helper.join(new_room_id, bob_user_id, tok=bob_access_token) self.assertIsNotNone(new_room_id)
# Now have Bob initiate an initial sync (in order to get a since token) # Have Bob join the room
channel = self.make_request( self.helper.invite(
"GET", new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
self.sync_endpoint, )
access_token=bob_access_token, self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch_token = channel.json_body["next_batch"]
# ...and then an incremental sync. This should block until the sync stream is woken up, # Now have Bob initiate an initial sync (in order to get a since token)
# which we hope will happen as a result of Alice updating their device list. channel = self.make_request(
bob_sync_channel = self.make_request( "GET",
"GET", self.sync_endpoint,
f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000", access_token=bob_access_token,
access_token=bob_access_token, )
# Start the request, then continue on. self.assertEqual(channel.code, 200, channel.json_body)
await_result=False, next_batch_token = channel.json_body["next_batch"]
)
# Have alice update their device list # ...and then an incremental sync. This should block until the sync stream is woken up,
channel = self.make_request( # which we hope will happen as a result of Alice updating their device list.
"PUT", bob_sync_channel = self.make_request(
f"/devices/{test_device_id}", "GET",
{ f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000",
"display_name": "New Device Name", access_token=bob_access_token,
}, # Start the request, then continue on.
access_token=alice_access_token, await_result=False,
) )
self.assertEqual(channel.code, 200, channel.json_body)
# Check that bob's incremental sync contains the updated device list. # Have alice update their device list
# If not, the client would only receive the device list update on the channel = self.make_request(
# *next* sync. "PUT",
bob_sync_channel.await_result() f"/devices/{test_device_id}",
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) {
"display_name": "New Device Name",
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( # Check that bob's incremental sync contains the updated device list.
"changed", [] # If not, the client would only receive the device list update on the
) # *next* sync.
self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body) bob_sync_channel.await_result()
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
def test_not_receiving_local_device_list_changes(self) -> None: changed_device_lists = bob_sync_channel.json_body.get(
"""Tests a local users DO NOT receive device updates from each other if they do not "device_lists", {}
share a room. ).get("changed", [])
""" self.assertIn(
# Register two users alice_user_id, changed_device_lists, bob_sync_channel.json_body
test_device_id = "TESTDEVICE" )
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
bob_user_id = self.register_user("bob", "ponyponypony") def test_not_receiving_local_device_list_changes(self) -> None:
bob_access_token = self.login(bob_user_id, "ponyponypony") """Tests a local users DO NOT receive device updates from each other if they do not
share a room.
"""
# Register two users
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# These users do not share a room. They are lonely. bob_user_id = self.register_user("bob", "ponyponypony")
bob_access_token = self.login(bob_user_id, "ponyponypony")
# Have Bob initiate an initial sync (in order to get a since token) # These users do not share a room. They are lonely.
channel = self.make_request(
"GET",
self.sync_endpoint,
access_token=bob_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch_token = channel.json_body["next_batch"]
# ...and then an incremental sync. This should block until the sync stream is woken up, # Have Bob initiate an initial sync (in order to get a since token)
# which we hope will happen as a result of Alice updating their device list. channel = self.make_request(
bob_sync_channel = self.make_request( "GET",
"GET", self.sync_endpoint,
f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000", access_token=bob_access_token,
access_token=bob_access_token, )
# Start the request, then continue on. self.assertEqual(channel.code, 200, channel.json_body)
await_result=False, next_batch_token = channel.json_body["next_batch"]
)
# Have alice update their device list # ...and then an incremental sync. This should block until the sync stream is woken up,
channel = self.make_request( # which we hope will happen as a result of Alice updating their device list.
"PUT", bob_sync_channel = self.make_request(
f"/devices/{test_device_id}", "GET",
{ f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000",
"display_name": "New Device Name", access_token=bob_access_token,
}, # Start the request, then continue on.
access_token=alice_access_token, await_result=False,
) )
self.assertEqual(channel.code, 200, channel.json_body)
# Check that bob's incremental sync does not contain the updated device list. # Have alice update their device list
bob_sync_channel.await_result() channel = self.make_request(
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) "PUT",
f"/devices/{test_device_id}",
{
"display_name": "New Device Name",
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( # Check that bob's incremental sync does not contain the updated device list.
"changed", [] bob_sync_channel.await_result()
) self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
self.assertNotIn(
alice_user_id, changed_device_lists, bob_sync_channel.json_body
)
def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None: changed_device_lists = bob_sync_channel.json_body.get(
"""Tests that a user with no rooms still receives their own device list updates""" "device_lists", {}
test_device_id = "TESTDEVICE" ).get("changed", [])
self.assertNotIn(
alice_user_id, changed_device_lists, bob_sync_channel.json_body
)
# Register a user and login, creating a device def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
alice_user_id = self.register_user("alice", "correcthorse") """Tests that a user with no rooms still receives their own device list updates"""
alice_access_token = self.login( test_device_id = "TESTDEVICE"
alice_user_id, "correcthorse", device_id=test_device_id
)
# Request an initial sync # Register a user and login, creating a device
channel = self.make_request( alice_user_id = self.register_user("alice", "correcthorse")
"GET", self.sync_endpoint, access_token=alice_access_token alice_access_token = self.login(
) alice_user_id, "correcthorse", device_id=test_device_id
self.assertEqual(channel.code, 200, channel.json_body) )
next_batch = channel.json_body["next_batch"]
# Now, make an incremental sync request. # Request an initial sync
# It won't return until something has happened channel = self.make_request(
incremental_sync_channel = self.make_request( "GET", self.sync_endpoint, access_token=alice_access_token
"GET", )
f"{self.sync_endpoint}?since={next_batch}&timeout=30000", self.assertEqual(channel.code, 200, channel.json_body)
access_token=alice_access_token, next_batch = channel.json_body["next_batch"]
await_result=False,
)
# Change our device's display name # Now, make an incremental sync request.
channel = self.make_request( # It won't return until something has happened
"PUT", incremental_sync_channel = self.make_request(
f"devices/{test_device_id}", "GET",
{ f"{self.sync_endpoint}?since={next_batch}&timeout=30000",
"display_name": "freeze ray", access_token=alice_access_token,
}, await_result=False,
access_token=alice_access_token, )
)
self.assertEqual(channel.code, 200, channel.json_body)
# The sync should now have returned # Change our device's display name
incremental_sync_channel.await_result(timeout_ms=20000) channel = self.make_request(
self.assertEqual(incremental_sync_channel.code, 200, channel.json_body) "PUT",
f"devices/{test_device_id}",
{
"display_name": "freeze ray",
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
# We should have received notification that the (user's) device has changed # The sync should now have returned
device_list_changes = incremental_sync_channel.json_body.get( incremental_sync_channel.await_result(timeout_ms=20000)
"device_lists", {} self.assertEqual(incremental_sync_channel.code, 200, channel.json_body)
).get("changed", [])
self.assertIn( # We should have received notification that the (user's) device has changed
alice_user_id, device_list_changes, incremental_sync_channel.json_body device_list_changes = incremental_sync_channel.json_body.get(
) "device_lists", {}
).get("changed", [])
self.assertIn(
alice_user_id, device_list_changes, incremental_sync_channel.json_body
)
class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase): class DeviceOneTimeKeysSyncTestCaseBase(unittest.HomeserverTestCase):
"""Tests regarding device one time keys (`device_one_time_keys_count`) changes.""" """Tests regarding device one time keys (`device_one_time_keys_count`) changes."""
servlets = [ servlets = [
synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets,
login.register_servlets, login.register_servlets,
sync.register_servlets, sync.register_servlets,
devices.register_servlets, devices.register_servlets,
] ]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_endpoint = "/sync" self.sync_endpoint = "/sync"
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
def test_no_device_one_time_keys(self) -> None: def test_no_device_one_time_keys(self) -> None:
""" """
Tests when no one time keys set, it still has the default `signed_curve25519` in Tests when no one time keys set, it still has the default `signed_curve25519` in
`device_one_time_keys_count` `device_one_time_keys_count`
""" """
test_device_id = "TESTDEVICE" test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse") alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login( alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id alice_user_id, "correcthorse", device_id=test_device_id
) )
# Request an initial sync # Request an initial sync
channel = self.make_request( channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token "GET", self.sync_endpoint, access_token=alice_access_token
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts # Check for those one time key counts
self.assertDictEqual( self.assertDictEqual(
channel.json_body["device_one_time_keys_count"], channel.json_body["device_one_time_keys_count"],
# Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
{"signed_curve25519": 0},
channel.json_body["device_one_time_keys_count"],
)
def test_returns_device_one_time_keys(self) -> None:
"""
Tests that one time keys for the device/user are counted correctly in the `/sync`
response
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Upload one time keys for the user/device
keys: JsonDict = {
"alg1:k1": "key1",
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"alg2:k3": {"key": "key3"},
}
res = self.get_success(
self.e2e_keys_handler.upload_keys_for_user(
alice_user_id, test_device_id, {"one_time_keys": keys}
)
)
# Note that "signed_curve25519" is always returned in key count responses # Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until # regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
{"signed_curve25519": 0}, self.assertDictEqual(
channel.json_body["device_one_time_keys_count"], res,
) {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}},
def test_returns_device_one_time_keys(self) -> None:
"""
Tests that one time keys for the device/user are counted correctly in the `/sync`
response
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Upload one time keys for the user/device
keys: JsonDict = {
"alg1:k1": "key1",
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"alg2:k3": {"key": "key3"},
}
res = self.get_success(
self.e2e_keys_handler.upload_keys_for_user(
alice_user_id, test_device_id, {"one_time_keys": keys}
) )
)
# Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
self.assertDictEqual(
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
)
# Request an initial sync # Request an initial sync
channel = self.make_request( channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token "GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
self.assertDictEqual(
channel.json_body["device_one_time_keys_count"],
{"alg1": 1, "alg2": 2, "signed_curve25519": 0},
channel.json_body["device_one_time_keys_count"],
)
class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
"""Tests regarding device one time keys (`device_unused_fallback_key_types`) changes."""
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_endpoint = "/sync"
self.store = self.hs.get_datastores().main
self.e2e_keys_handler = hs.get_e2e_keys_handler()
def test_no_device_unused_fallback_key(self) -> None:
"""
Test when no unused fallback key is set, it just returns an empty list. The MSC
says "The device_unused_fallback_key_types parameter must be present if the
server supports fallback keys.",
https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
self.assertListEqual(
channel.json_body["device_unused_fallback_key_types"],
[],
channel.json_body["device_unused_fallback_key_types"],
)
def test_returns_device_one_time_keys(self) -> None:
"""
Tests that device unused fallback key type is returned correctly in the `/sync`
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# We shouldn't have any unused fallback keys yet
res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
)
self.assertEqual(res, [])
# Upload a fallback key for the user/device
fallback_key = {"alg1:k1": "fallback_key1"}
self.get_success(
self.e2e_keys_handler.upload_keys_for_user(
alice_user_id,
test_device_id,
{"fallback_keys": fallback_key},
) )
) self.assertEqual(channel.code, 200, channel.json_body)
# We should now have an unused alg1 key
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
)
self.assertEqual(fallback_res, ["alg1"], fallback_res)
# Request an initial sync # Check for those one time key counts
channel = self.make_request( self.assertDictEqual(
"GET", self.sync_endpoint, access_token=alice_access_token channel.json_body["device_one_time_keys_count"],
) {"alg1": 1, "alg2": 2, "signed_curve25519": 0},
self.assertEqual(channel.code, 200, channel.json_body) channel.json_body["device_one_time_keys_count"],
)
# Check for the unused fallback key types class DeviceUnusedFallbackKeySyncTestCaseBase(unittest.HomeserverTestCase):
self.assertListEqual( """Tests regarding device one time keys (`device_unused_fallback_key_types`) changes."""
channel.json_body["device_unused_fallback_key_types"],
["alg1"], servlets = [
channel.json_body["device_unused_fallback_key_types"], synapse.rest.admin.register_servlets,
) login.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_endpoint = "/sync"
self.store = self.hs.get_datastores().main
self.e2e_keys_handler = hs.get_e2e_keys_handler()
def test_no_device_unused_fallback_key(self) -> None:
"""
Test when no unused fallback key is set, it just returns an empty list. The MSC
says "The device_unused_fallback_key_types parameter must be present if the
server supports fallback keys.",
https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
self.assertListEqual(
channel.json_body["device_unused_fallback_key_types"],
[],
channel.json_body["device_unused_fallback_key_types"],
)
def test_returns_device_one_time_keys(self) -> None:
"""
Tests that device unused fallback key type is returned correctly in the `/sync`
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# We shouldn't have any unused fallback keys yet
res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(
alice_user_id, test_device_id
)
)
self.assertEqual(res, [])
# Upload a fallback key for the user/device
fallback_key = {"alg1:k1": "fallback_key1"}
self.get_success(
self.e2e_keys_handler.upload_keys_for_user(
alice_user_id,
test_device_id,
{"fallback_keys": fallback_key},
)
)
# We should now have an unused alg1 key
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(
alice_user_id, test_device_id
)
)
self.assertEqual(fallback_res, ["alg1"], fallback_res)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for the unused fallback key types
self.assertListEqual(
channel.json_body["device_unused_fallback_key_types"],
["alg1"],
channel.json_body["device_unused_fallback_key_types"],
)
class DeviceListSyncTestCase(NotTested.DeviceListSyncTestCaseBase):
# See DeviceListSyncTestCaseBase above
pass
class DeviceOneTimeKeysSyncTestCase(NotTested.DeviceOneTimeKeysSyncTestCaseBase):
# See DeviceOneTimeKeysSyncTestCaseBase above
pass
class DeviceUnusedFallbackKeySyncTestCase(
NotTested.DeviceUnusedFallbackKeySyncTestCaseBase
):
# See DeviceUnusedFallbackKeySyncTestCaseBase above
pass
class ExcludeRoomTestCase(unittest.HomeserverTestCase): class ExcludeRoomTestCase(unittest.HomeserverTestCase):