Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2024-06-19 10:34:40 +01:00
commit 4f308ea362
14 changed files with 438 additions and 116 deletions

changelog.d/17283.bugfix Normal file

@ -0,0 +1 @@
Fix a long-standing bug where an invalid 'from' parameter to [`/notifications`](https://spec.matrix.org/v1.10/client-server-api/#get_matrixclientv3notifications) would result in an Internal Server Error.
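For context, the `from` value that `/notifications` hands out as `next_token` is a stringified integer, so a non-numeric value can never match a stored token and previously fell through to a 500. A minimal standalone sketch of the validation approach this fix takes (`parse_from_token` is a hypothetical helper for illustration; the real change is in the servlet diff further down):

from typing import Optional

def parse_from_token(raw: Optional[str]) -> Optional[int]:
    """Parse the 'from' query parameter, rejecting non-numeric tokens."""
    if raw is None:
        # Parameter absent: start from the newest notifications.
        return None
    try:
        return int(raw)
    except ValueError:
        # This cannot be a token the server handed out; the servlet turns
        # this into a 400 rather than an Internal Server Error.
        raise ValueError('Query parameter "from" contains unrecognised token')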

changelog.d/17291.misc Normal file

@ -0,0 +1 @@
Do not block event sending/receiving while calculating large event auth chains.

View file

@ -0,0 +1,2 @@
`register_new_matrix_user` now supports a --password-file flag, which
is useful for scripting.

debian/changelog vendored

@ -1,3 +1,9 @@
matrix-synapse-py3 (1.109.0+nmu1) UNRELEASED; urgency=medium
* `register_new_matrix_user` now supports a --password-file flag.
-- Synapse Packaging team <packages@matrix.org> Tue, 18 Jun 2024 13:29:36 +0100
matrix-synapse-py3 (1.109.0) stable; urgency=medium
* New synapse release 1.109.0.

View file

@ -31,8 +31,12 @@ A sample YAML file accepted by `register_new_matrix_user` is described below:
Local part of the new user. Will prompt if omitted.
* `-p`, `--password`:
-New password for user. Will prompt if omitted. Supplying the password
-on the command line is not recommended. Use the STDIN instead.
+New password for user. Will prompt if this option and `--password-file` are omitted.
+Supplying the password on the command line is not recommended.
* `--password-file`:
File containing the new password for user. If set, overrides `--password`.
This is a more secure alternative to specifying the password on the command line.
* `-a`, `--admin`:
Register new user as an admin. Will prompt if omitted.
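A short sketch of how the password ends up being resolved once these flags are parsed, mirroring the change to main() in the script diff below (simplified: the real script reads the file via its own _read_file() helper and reports errors through the usual CLI paths):

import argparse
from pathlib import Path
from typing import Optional

def resolve_password(args: argparse.Namespace) -> Optional[str]:
    """Prefer --password-file over --password; the two flags are mutually exclusive."""
    if args.password_file:
        # Strip the trailing newline so files written with `echo secret > pw.txt` work.
        return Path(args.password_file).read_text().strip()
    return args.password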

View file

@ -173,11 +173,18 @@ def main() -> None:
default=None,
help="Local part of the new user. Will prompt if omitted.",
)
-parser.add_argument(
+password_group = parser.add_mutually_exclusive_group()
password_group.add_argument(
"-p", "-p",
"--password", "--password",
default=None, default=None,
help="New password for user. Will prompt if omitted.", help="New password for user. Will prompt for a password if "
"this flag and `--password-file` are both omitted.",
)
password_group.add_argument(
"--password-file",
default=None,
help="File containing the new password for user. If set, will override `--password`.",
)
parser.add_argument(
"-t",
@ -247,6 +254,11 @@ def main() -> None:
print(_NO_SHARED_SECRET_OPTS_ERROR, file=sys.stderr)
sys.exit(1)
if args.password_file:
password = _read_file(args.password_file, "password-file").strip()
else:
password = args.password
if args.server_url:
server_url = args.server_url
elif config is not None:
@ -269,9 +281,7 @@ def main() -> None:
if args.admin or args.no_admin:
admin = args.admin
-register_new_user(
-args.user, args.password, server_url, secret, admin, args.user_type
-)
+register_new_user(args.user, password, server_url, secret, admin, args.user_type)
def _read_file(file_path: Any, config_path: str) -> str:

View file

@ -32,6 +32,7 @@ from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ...api.errors import SynapseError
from ._base import client_patterns
if TYPE_CHECKING:
@ -56,7 +57,22 @@ class NotificationsServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
-from_token = parse_string(request, "from", required=False)
+# While this is intended to be "string" to clients, the 'from' token
# is actually based on a numeric ID. So it must parse to an int.
from_token_str = parse_string(request, "from", required=False)
if from_token_str is not None:
# Parse to an integer.
try:
from_token = int(from_token_str)
except ValueError:
# If it doesn't parse to an integer, then this cannot possibly be a valid
# pagination token, as we only hand out integers.
raise SynapseError(
400, 'Query parameter "from" contains unrecognised token'
)
else:
from_token = None
limit = parse_integer(request, "limit", default=50)
only = parse_string(request, "only", required=False)

View file

@ -617,6 +617,17 @@ class EventsPersistenceStorageController:
room_id, chunk
)
with Measure(self._clock, "calculate_chain_cover_index_for_events"):
# We now calculate chain ID/sequence numbers for any state events we're
# persisting. We ignore out of band memberships as we're not in the room
# and won't have their auth chain (we'll fix it up later if we join the
# room).
#
# See: docs/auth_chain_difference_algorithm.md
new_event_links = await self.persist_events_store.calculate_chain_cover_index_for_events(
room_id, [e for e, _ in chunk]
)
await self.persist_events_store._persist_events_and_state_updates(
room_id,
chunk,
@ -624,6 +635,7 @@ class EventsPersistenceStorageController:
new_forward_extremities=new_forward_extremities,
use_negative_stream_ordering=backfilled,
inhibit_local_membership_updates=backfilled,
new_event_links=new_event_links,
)
return replaced_events
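The structural point of this change is that the potentially expensive chain cover calculation now runs before the main persistence call instead of inside its write transaction, so a large auth chain no longer blocks other event writes. A rough sketch of the resulting flow (method names as in the diff above; the wrapper function and the argument list are simplified for illustration):

async def _persist_chunk(persist_events_store, room_id, chunk):
    # Compute chain ID/sequence numbers up front, in its own short DB
    # interaction, rather than inside the big write transaction.
    new_event_links = await persist_events_store.calculate_chain_cover_index_for_events(
        room_id, [event for event, _context in chunk]
    )

    # The write transaction then only inserts the precomputed rows via
    # _persist_event_auth_chain_txn / _persist_chain_cover_index.
    await persist_events_store._persist_events_and_state_updates(
        room_id,
        chunk,
        new_event_links=new_event_links,
    )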

View file

@ -1829,7 +1829,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
async def get_push_actions_for_user(
self,
user_id: str,
-before: Optional[str] = None,
+before: Optional[int] = None,
limit: int = 50,
only_highlight: bool = False,
) -> List[UserPushAction]:

View file

@ -34,7 +34,6 @@ from typing import (
Optional,
Set,
Tuple,
-Union,
cast,
)
@ -100,6 +99,23 @@ class DeltaState:
return not self.to_delete and not self.to_insert and not self.no_longer_in_room
@attr.s(slots=True, auto_attribs=True)
class NewEventChainLinks:
"""Information about new auth chain links that need to be added to the DB.
Attributes:
chain_id, sequence_number: the IDs corresponding to the event being
inserted, and the starting point of the links
links: Lists the links that need to be added, 2-tuple of the chain
ID/sequence number of the end point of the link.
"""
chain_id: int
sequence_number: int
links: List[Tuple[int, int]] = attr.Factory(list)
class PersistEventsStore:
"""Contains all the functions for writing events to the database.
@ -148,6 +164,7 @@ class PersistEventsStore:
*,
state_delta_for_room: Optional[DeltaState],
new_forward_extremities: Optional[Set[str]],
new_event_links: Dict[str, NewEventChainLinks],
use_negative_stream_ordering: bool = False,
inhibit_local_membership_updates: bool = False,
) -> None:
@ -217,6 +234,7 @@ class PersistEventsStore:
inhibit_local_membership_updates=inhibit_local_membership_updates,
state_delta_for_room=state_delta_for_room,
new_forward_extremities=new_forward_extremities,
new_event_links=new_event_links,
)
persist_event_counter.inc(len(events_and_contexts))
@ -243,6 +261,87 @@ class PersistEventsStore:
(room_id,), frozenset(new_forward_extremities)
)
async def calculate_chain_cover_index_for_events(
self, room_id: str, events: Collection[EventBase]
) -> Dict[str, NewEventChainLinks]:
# Filter to state events, and ensure there are no duplicates.
state_events = []
seen_events = set()
for event in events:
if not event.is_state() or event.event_id in seen_events:
continue
state_events.append(event)
seen_events.add(event.event_id)
if not state_events:
return {}
return await self.db_pool.runInteraction(
"_calculate_chain_cover_index_for_events",
self.calculate_chain_cover_index_for_events_txn,
room_id,
state_events,
)
def calculate_chain_cover_index_for_events_txn(
self, txn: LoggingTransaction, room_id: str, state_events: Collection[EventBase]
) -> Dict[str, NewEventChainLinks]:
# We now calculate chain ID/sequence numbers for any state events we're
# persisting. We ignore out of band memberships as we're not in the room
# and won't have their auth chain (we'll fix it up later if we join the
# room).
#
# See: docs/auth_chain_difference_algorithm.md
# We ignore legacy rooms that we aren't filling the chain cover index
# for.
row = self.db_pool.simple_select_one_txn(
txn,
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "has_auth_chain_index"),
allow_none=True,
)
if row is None:
return {}
# Filter out already persisted events.
rows = self.db_pool.simple_select_many_txn(
txn,
table="events",
column="event_id",
iterable=[e.event_id for e in state_events],
keyvalues={},
retcols=("event_id",),
)
already_persisted_events = {event_id for event_id, in rows}
state_events = [
event
for event in state_events
if event.event_id not in already_persisted_events
]
if not state_events:
return {}
# We need to know the type/state_key and auth events of the events we're
# calculating chain IDs for. We don't rely on having the full Event
# instances as we'll potentially be pulling more events from the DB and
# we don't need the overhead of fetching/parsing the full event JSON.
event_to_types = {e.event_id: (e.type, e.state_key) for e in state_events}
event_to_auth_chain = {e.event_id: e.auth_event_ids() for e in state_events}
event_to_room_id = {e.event_id: e.room_id for e in state_events}
return self._calculate_chain_cover_index(
txn,
self.db_pool,
self.store.event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
)
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
"""Filter the supplied list of event_ids to get those which are prev_events of
existing (non-outlier/rejected) events.
@ -358,6 +457,7 @@ class PersistEventsStore:
inhibit_local_membership_updates: bool,
state_delta_for_room: Optional[DeltaState],
new_forward_extremities: Optional[Set[str]],
new_event_links: Dict[str, NewEventChainLinks],
) -> None:
"""Insert some number of room events into the necessary database tables.
@ -466,7 +566,9 @@ class PersistEventsStore:
# Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts)
-self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
+self._persist_event_auth_chain_txn(
txn, [e for e, _ in events_and_contexts], new_event_links
)
# _store_rejected_events_txn filters out any events which were
# rejected, and returns the filtered list.
@ -496,6 +598,7 @@ class PersistEventsStore:
self,
txn: LoggingTransaction,
events: List[EventBase],
new_event_links: Dict[str, NewEventChainLinks],
) -> None:
# We only care about state events, so this if there are no state events.
if not any(e.is_state() for e in events):
@ -519,59 +622,8 @@ class PersistEventsStore:
],
)
-# We now calculate chain ID/sequence numbers for any state events we're
-# persisting. We ignore out of band memberships as we're not in the room
+if new_event_links:
+self._persist_chain_cover_index(txn, self.db_pool, new_event_links)
# and won't have their auth chain (we'll fix it up later if we join the
# room).
#
# See: docs/auth_chain_difference_algorithm.md
# We ignore legacy rooms that we aren't filling the chain cover index
# for.
rows = cast(
List[Tuple[str, Optional[Union[int, bool]]]],
self.db_pool.simple_select_many_txn(
txn,
table="rooms",
column="room_id",
iterable={event.room_id for event in events if event.is_state()},
keyvalues={},
retcols=("room_id", "has_auth_chain_index"),
),
)
rooms_using_chain_index = {
room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
}
state_events = {
event.event_id: event
for event in events
if event.is_state() and event.room_id in rooms_using_chain_index
}
if not state_events:
return
# We need to know the type/state_key and auth events of the events we're
# calculating chain IDs for. We don't rely on having the full Event
# instances as we'll potentially be pulling more events from the DB and
# we don't need the overhead of fetching/parsing the full event JSON.
event_to_types = {
e.event_id: (e.type, e.state_key) for e in state_events.values()
}
event_to_auth_chain = {
e.event_id: e.auth_event_ids() for e in state_events.values()
}
event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
self._add_chain_cover_index(
txn,
self.db_pool,
self.store.event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
)
@classmethod
def _add_chain_cover_index(
@ -583,6 +635,35 @@ class PersistEventsStore:
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, StrCollection],
) -> None:
"""Calculate and persist the chain cover index for the given events.
Args:
event_to_room_id: Event ID to the room ID of the event
event_to_types: Event ID to type and state_key of the event
event_to_auth_chain: Event ID to list of auth event IDs of the
event (events with no auth events can be excluded).
"""
new_event_links = cls._calculate_chain_cover_index(
txn,
db_pool,
event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
)
cls._persist_chain_cover_index(txn, db_pool, new_event_links)
@classmethod
def _calculate_chain_cover_index(
cls,
txn: LoggingTransaction,
db_pool: DatabasePool,
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, StrCollection],
) -> Dict[str, NewEventChainLinks]:
"""Calculate the chain cover index for the given events. """Calculate the chain cover index for the given events.
Args: Args:
@ -590,6 +671,10 @@ class PersistEventsStore:
event_to_types: Event ID to type and state_key of the event
event_to_auth_chain: Event ID to list of auth event IDs of the
event (events with no auth events can be excluded).
Returns:
A mapping with any new auth chain links we need to add, keyed by
event ID.
""" """
# Map from event ID to chain ID/sequence number. # Map from event ID to chain ID/sequence number.
@ -708,11 +793,11 @@ class PersistEventsStore:
room_id = event_to_room_id.get(event_id)
if room_id:
e_type, state_key = event_to_types[event_id]
-db_pool.simple_insert_txn(
+db_pool.simple_upsert_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={"event_id": event_id},
values={
-"event_id": event_id,
"room_id": room_id, "room_id": room_id,
"type": e_type, "type": e_type,
"state_key": state_key, "state_key": state_key,
@ -724,7 +809,7 @@ class PersistEventsStore:
break
if not events_to_calc_chain_id_for:
-return
+return {}
# Allocate chain ID/sequence numbers to each new event.
new_chain_tuples = cls._allocate_chain_ids(
@ -739,23 +824,10 @@ class PersistEventsStore:
)
chain_map.update(new_chain_tuples)
-db_pool.simple_insert_many_txn(
-txn,
-table="event_auth_chains",
-keys=("event_id", "chain_id", "sequence_number"),
+to_return = {
+event_id: NewEventChainLinks(chain_id, sequence_number)
+for event_id, (chain_id, sequence_number) in new_chain_tuples.items()
+}
values=[
(event_id, c_id, seq)
for event_id, (c_id, seq) in new_chain_tuples.items()
],
)
db_pool.simple_delete_many_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={},
column="event_id",
values=new_chain_tuples,
)
# Now we need to calculate any new links between chains caused by
# the new events.
@ -825,10 +897,38 @@ class PersistEventsStore:
auth_chain_id, auth_sequence_number = chain_map[auth_id]
# Step 2a, add link between the event and auth event
to_return[event_id].links.append((auth_chain_id, auth_sequence_number))
chain_links.add_link(
(chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
)
return to_return
@classmethod
def _persist_chain_cover_index(
cls,
txn: LoggingTransaction,
db_pool: DatabasePool,
new_event_links: Dict[str, NewEventChainLinks],
) -> None:
db_pool.simple_insert_many_txn(
txn,
table="event_auth_chains",
keys=("event_id", "chain_id", "sequence_number"),
values=[
(event_id, new_links.chain_id, new_links.sequence_number)
for event_id, new_links in new_event_links.items()
],
)
db_pool.simple_delete_many_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={},
column="event_id",
values=new_event_links,
)
db_pool.simple_insert_many_txn(
txn,
table="event_auth_chain_links",
@ -838,7 +938,16 @@ class PersistEventsStore:
"target_chain_id", "target_chain_id",
"target_sequence_number", "target_sequence_number",
), ),
values=list(chain_links.get_additions()), values=[
(
new_links.chain_id,
new_links.sequence_number,
target_chain_id,
target_sequence_number,
)
for new_links in new_event_links.values()
for (target_chain_id, target_sequence_number) in new_links.links
],
)
@staticmethod

View file

@ -688,7 +688,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
channel = self.make_request(
"GET",
-"/notifications?from=",
+"/notifications",
access_token=tok,
)
self.assertEqual(channel.code, 200, channel.result)

View file

@ -18,6 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import List, Optional, Tuple
from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@ -48,6 +49,14 @@ class HTTPPusherTests(HomeserverTestCase):
self.sync_handler = homeserver.get_sync_handler()
self.auth_handler = homeserver.get_auth_handler()
self.user_id = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
# Create a room
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
fed_transport_client = Mock(spec=["send_transaction"])
@ -61,32 +70,22 @@ class HTTPPusherTests(HomeserverTestCase):
""" """
Local users will get notified for invites Local users will get notified for invites
""" """
user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")
other_user_id = self.register_user("otheruser", "pass")
other_access_token = self.login("otheruser", "pass")
# Create a room
room = self.helper.create_room_as(user_id, tok=access_token)
# Check we start with no pushes
-channel = self.make_request(
+self._request_notifications(from_token=None, limit=1, expected_count=0)
"GET",
"/notifications",
access_token=other_access_token,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(len(channel.json_body["notifications"]), 0, channel.json_body)
# Send an invite
-self.helper.invite(room=room, src=user_id, targ=other_user_id, tok=access_token)
+self.helper.invite(
room=self.room_id,
src=self.user_id,
targ=self.other_user_id,
tok=self.access_token,
)
# We should have a notification now
channel = self.make_request(
"GET",
"/notifications",
-access_token=other_access_token,
+access_token=self.other_access_token,
)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body)
@ -95,3 +94,139 @@ class HTTPPusherTests(HomeserverTestCase):
"invite", "invite",
channel.json_body, channel.json_body,
) )
def test_pagination_of_notifications(self) -> None:
"""
Check that pagination of notifications works.
"""
# Check we start with no pushes
self._request_notifications(from_token=None, limit=1, expected_count=0)
# Send an invite and have the other user join the room.
self.helper.invite(
room=self.room_id,
src=self.user_id,
targ=self.other_user_id,
tok=self.access_token,
)
self.helper.join(self.room_id, self.other_user_id, tok=self.other_access_token)
# Send 5 messages in the room and note down their event IDs.
sent_event_ids = []
for _ in range(5):
resp = self.helper.send_event(
self.room_id,
"m.room.message",
{"body": "honk", "msgtype": "m.text"},
tok=self.access_token,
)
sent_event_ids.append(resp["event_id"])
# We expect to get notifications for messages in reverse order.
# So reverse this list of event IDs to make it easier to compare
# against later.
sent_event_ids.reverse()
# We should have a few notifications now. Let's try and fetch the first 2.
notification_event_ids, _ = self._request_notifications(
from_token=None, limit=2, expected_count=2
)
# Check we got the expected event IDs back.
self.assertEqual(notification_event_ids, sent_event_ids[:2])
# Try requesting again without a 'from' query parameter. We should get the
# same two notifications back.
notification_event_ids, next_token = self._request_notifications(
from_token=None, limit=2, expected_count=2
)
self.assertEqual(notification_event_ids, sent_event_ids[:2])
# Ask for the next 5 notifications, though there should only be
# 4 remaining; the next 3 messages and the invite.
#
# We need to use the "next_token" from the response as the "from"
# query parameter in the next request in order to paginate.
notification_event_ids, next_token = self._request_notifications(
from_token=next_token, limit=5, expected_count=4
)
# Ensure we chop off the invite on the end.
notification_event_ids = notification_event_ids[:-1]
self.assertEqual(notification_event_ids, sent_event_ids[2:])
def _request_notifications(
self, from_token: Optional[str], limit: int, expected_count: int
) -> Tuple[List[str], str]:
"""
Make a request to /notifications to get the latest events to be notified about.
Only the event IDs are returned. The request is made by the "other user".
Args:
from_token: An optional starting parameter.
limit: The maximum number of results to return.
expected_count: The number of events to expect in the response.
Returns:
A list of event IDs that the client should be notified about.
Events are returned newest-first.
"""
# Construct the request path.
path = f"/notifications?limit={limit}"
if from_token is not None:
path += f"&from={from_token}"
channel = self.make_request(
"GET",
path,
access_token=self.other_access_token,
)
self.assertEqual(channel.code, 200)
self.assertEqual(
len(channel.json_body["notifications"]), expected_count, channel.json_body
)
# Extract the necessary data from the response.
next_token = channel.json_body["next_token"]
event_ids = [
event["event"]["event_id"] for event in channel.json_body["notifications"]
]
return event_ids, next_token
def test_parameters(self) -> None:
"""
Test that appropriate errors are returned when query parameters are malformed.
"""
# Test that no parameters are required.
channel = self.make_request(
"GET",
"/notifications",
access_token=self.other_access_token,
)
self.assertEqual(channel.code, 200)
# Test that limit cannot be negative
channel = self.make_request(
"GET",
"/notifications?limit=-1",
access_token=self.other_access_token,
)
self.assertEqual(channel.code, 400)
# Test that the 'limit' parameter must be an integer.
channel = self.make_request(
"GET",
"/notifications?limit=foobar",
access_token=self.other_access_token,
)
self.assertEqual(channel.code, 400)
# Test that the 'from' parameter must be an integer.
channel = self.make_request(
"GET",
"/notifications?from=osborne",
access_token=self.other_access_token,
)
self.assertEqual(channel.code, 400)

View file

@ -447,7 +447,14 @@ class EventChainStoreTestCase(HomeserverTestCase):
)
# Actually call the function that calculates the auth chain stuff.
-persist_events_store._persist_event_auth_chain_txn(txn, events)
+new_event_links = (
persist_events_store.calculate_chain_cover_index_for_events_txn(
txn, events[0].room_id, [e for e in events if e.is_state()]
)
)
persist_events_store._persist_event_auth_chain_txn(
txn, events, new_event_links
)
self.get_success(
persist_events_store.db_pool.runInteraction(

View file

@ -365,12 +365,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
},
)
-self.persist_events._persist_event_auth_chain_txn(
-txn,
-[
+events = [
cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
for event_id in AUTH_GRAPH
-],
+]
new_event_links = (
self.persist_events.calculate_chain_cover_index_for_events_txn(
txn, room_id, [e for e in events if e.is_state()]
)
)
self.persist_events._persist_event_auth_chain_txn(
txn,
events,
new_event_links,
)
self.get_success(
@ -628,13 +635,20 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
# Insert all events apart from 'B'
-self.persist_events._persist_event_auth_chain_txn(
-txn,
-[
+events = [
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
for event_id in auth_graph
if event_id != "b"
-],
+]
new_event_links = (
self.persist_events.calculate_chain_cover_index_for_events_txn(
txn, room_id, [e for e in events if e.is_state()]
)
)
self.persist_events._persist_event_auth_chain_txn(
txn,
events,
new_event_links,
)
# Now we insert the event 'B' without a chain cover, by temporarily
@ -647,9 +661,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
updatevalues={"has_auth_chain_index": False},
)
events = [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))]
new_event_links = (
self.persist_events.calculate_chain_cover_index_for_events_txn(
txn, room_id, [e for e in events if e.is_state()]
)
)
self.persist_events._persist_event_auth_chain_txn(
-txn,
-[cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
+txn, events, new_event_links
)
self.store.db_pool.simple_update_txn(