mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-19 17:56:19 +03:00
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
This commit is contained in:
commit
4f308ea362
14 changed files with 438 additions and 116 deletions
1
changelog.d/17283.bugfix
Normal file
1
changelog.d/17283.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix a long-standing bug where an invalid 'from' parameter to [`/notifications`](https://spec.matrix.org/v1.10/client-server-api/#get_matrixclientv3notifications) would result in an Internal Server Error.
|
1
changelog.d/17291.misc
Normal file
1
changelog.d/17291.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Do not block event sending/receiving while calulating large event auth chains.
|
2
changelog.d/17294.feature
Normal file
2
changelog.d/17294.feature
Normal file
|
@ -0,0 +1,2 @@
|
|||
`register_new_matrix_user` now supports a --password-file flag, which
|
||||
is useful for scripting.
|
6
debian/changelog
vendored
6
debian/changelog
vendored
|
@ -1,3 +1,9 @@
|
|||
matrix-synapse-py3 (1.109.0+nmu1) UNRELEASED; urgency=medium
|
||||
|
||||
* `register_new_matrix_user` now supports a --password-file flag.
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> Tue, 18 Jun 2024 13:29:36 +0100
|
||||
|
||||
matrix-synapse-py3 (1.109.0) stable; urgency=medium
|
||||
|
||||
* New synapse release 1.109.0.
|
||||
|
|
8
debian/register_new_matrix_user.ronn
vendored
8
debian/register_new_matrix_user.ronn
vendored
|
@ -31,8 +31,12 @@ A sample YAML file accepted by `register_new_matrix_user` is described below:
|
|||
Local part of the new user. Will prompt if omitted.
|
||||
|
||||
* `-p`, `--password`:
|
||||
New password for user. Will prompt if omitted. Supplying the password
|
||||
on the command line is not recommended. Use the STDIN instead.
|
||||
New password for user. Will prompt if this option and `--password-file` are omitted.
|
||||
Supplying the password on the command line is not recommended.
|
||||
|
||||
* `--password-file`:
|
||||
File containing the new password for user. If set, overrides `--password`.
|
||||
This is a more secure alternative to specifying the password on the command line.
|
||||
|
||||
* `-a`, `--admin`:
|
||||
Register new user as an admin. Will prompt if omitted.
|
||||
|
|
|
@ -173,11 +173,18 @@ def main() -> None:
|
|||
default=None,
|
||||
help="Local part of the new user. Will prompt if omitted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
password_group = parser.add_mutually_exclusive_group()
|
||||
password_group.add_argument(
|
||||
"-p",
|
||||
"--password",
|
||||
default=None,
|
||||
help="New password for user. Will prompt if omitted.",
|
||||
help="New password for user. Will prompt for a password if "
|
||||
"this flag and `--password-file` are both omitted.",
|
||||
)
|
||||
password_group.add_argument(
|
||||
"--password-file",
|
||||
default=None,
|
||||
help="File containing the new password for user. If set, will override `--password`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
|
@ -247,6 +254,11 @@ def main() -> None:
|
|||
print(_NO_SHARED_SECRET_OPTS_ERROR, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if args.password_file:
|
||||
password = _read_file(args.password_file, "password-file").strip()
|
||||
else:
|
||||
password = args.password
|
||||
|
||||
if args.server_url:
|
||||
server_url = args.server_url
|
||||
elif config is not None:
|
||||
|
@ -269,9 +281,7 @@ def main() -> None:
|
|||
if args.admin or args.no_admin:
|
||||
admin = args.admin
|
||||
|
||||
register_new_user(
|
||||
args.user, args.password, server_url, secret, admin, args.user_type
|
||||
)
|
||||
register_new_user(args.user, password, server_url, secret, admin, args.user_type)
|
||||
|
||||
|
||||
def _read_file(file_path: Any, config_path: str) -> str:
|
||||
|
|
|
@ -32,6 +32,7 @@ from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
|||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from ...api.errors import SynapseError
|
||||
from ._base import client_patterns
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -56,7 +57,22 @@ class NotificationsServlet(RestServlet):
|
|||
requester = await self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
from_token = parse_string(request, "from", required=False)
|
||||
# While this is intended to be "string" to clients, the 'from' token
|
||||
# is actually based on a numeric ID. So it must parse to an int.
|
||||
from_token_str = parse_string(request, "from", required=False)
|
||||
if from_token_str is not None:
|
||||
# Parse to an integer.
|
||||
try:
|
||||
from_token = int(from_token_str)
|
||||
except ValueError:
|
||||
# If it doesn't parse to an integer, then this cannot possibly be a valid
|
||||
# pagination token, as we only hand out integers.
|
||||
raise SynapseError(
|
||||
400, 'Query parameter "from" contains unrecognised token'
|
||||
)
|
||||
else:
|
||||
from_token = None
|
||||
|
||||
limit = parse_integer(request, "limit", default=50)
|
||||
only = parse_string(request, "only", required=False)
|
||||
|
||||
|
|
|
@ -617,6 +617,17 @@ class EventsPersistenceStorageController:
|
|||
room_id, chunk
|
||||
)
|
||||
|
||||
with Measure(self._clock, "calculate_chain_cover_index_for_events"):
|
||||
# We now calculate chain ID/sequence numbers for any state events we're
|
||||
# persisting. We ignore out of band memberships as we're not in the room
|
||||
# and won't have their auth chain (we'll fix it up later if we join the
|
||||
# room).
|
||||
#
|
||||
# See: docs/auth_chain_difference_algorithm.md
|
||||
new_event_links = await self.persist_events_store.calculate_chain_cover_index_for_events(
|
||||
room_id, [e for e, _ in chunk]
|
||||
)
|
||||
|
||||
await self.persist_events_store._persist_events_and_state_updates(
|
||||
room_id,
|
||||
chunk,
|
||||
|
@ -624,6 +635,7 @@ class EventsPersistenceStorageController:
|
|||
new_forward_extremities=new_forward_extremities,
|
||||
use_negative_stream_ordering=backfilled,
|
||||
inhibit_local_membership_updates=backfilled,
|
||||
new_event_links=new_event_links,
|
||||
)
|
||||
|
||||
return replaced_events
|
||||
|
|
|
@ -1829,7 +1829,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
|||
async def get_push_actions_for_user(
|
||||
self,
|
||||
user_id: str,
|
||||
before: Optional[str] = None,
|
||||
before: Optional[int] = None,
|
||||
limit: int = 50,
|
||||
only_highlight: bool = False,
|
||||
) -> List[UserPushAction]:
|
||||
|
|
|
@ -34,7 +34,6 @@ from typing import (
|
|||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
|
@ -100,6 +99,23 @@ class DeltaState:
|
|||
return not self.to_delete and not self.to_insert and not self.no_longer_in_room
|
||||
|
||||
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class NewEventChainLinks:
|
||||
"""Information about new auth chain links that need to be added to the DB.
|
||||
|
||||
Attributes:
|
||||
chain_id, sequence_number: the IDs corresponding to the event being
|
||||
inserted, and the starting point of the links
|
||||
links: Lists the links that need to be added, 2-tuple of the chain
|
||||
ID/sequence number of the end point of the link.
|
||||
"""
|
||||
|
||||
chain_id: int
|
||||
sequence_number: int
|
||||
|
||||
links: List[Tuple[int, int]] = attr.Factory(list)
|
||||
|
||||
|
||||
class PersistEventsStore:
|
||||
"""Contains all the functions for writing events to the database.
|
||||
|
||||
|
@ -148,6 +164,7 @@ class PersistEventsStore:
|
|||
*,
|
||||
state_delta_for_room: Optional[DeltaState],
|
||||
new_forward_extremities: Optional[Set[str]],
|
||||
new_event_links: Dict[str, NewEventChainLinks],
|
||||
use_negative_stream_ordering: bool = False,
|
||||
inhibit_local_membership_updates: bool = False,
|
||||
) -> None:
|
||||
|
@ -217,6 +234,7 @@ class PersistEventsStore:
|
|||
inhibit_local_membership_updates=inhibit_local_membership_updates,
|
||||
state_delta_for_room=state_delta_for_room,
|
||||
new_forward_extremities=new_forward_extremities,
|
||||
new_event_links=new_event_links,
|
||||
)
|
||||
persist_event_counter.inc(len(events_and_contexts))
|
||||
|
||||
|
@ -243,6 +261,87 @@ class PersistEventsStore:
|
|||
(room_id,), frozenset(new_forward_extremities)
|
||||
)
|
||||
|
||||
async def calculate_chain_cover_index_for_events(
|
||||
self, room_id: str, events: Collection[EventBase]
|
||||
) -> Dict[str, NewEventChainLinks]:
|
||||
# Filter to state events, and ensure there are no duplicates.
|
||||
state_events = []
|
||||
seen_events = set()
|
||||
for event in events:
|
||||
if not event.is_state() or event.event_id in seen_events:
|
||||
continue
|
||||
|
||||
state_events.append(event)
|
||||
seen_events.add(event.event_id)
|
||||
|
||||
if not state_events:
|
||||
return {}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"_calculate_chain_cover_index_for_events",
|
||||
self.calculate_chain_cover_index_for_events_txn,
|
||||
room_id,
|
||||
state_events,
|
||||
)
|
||||
|
||||
def calculate_chain_cover_index_for_events_txn(
|
||||
self, txn: LoggingTransaction, room_id: str, state_events: Collection[EventBase]
|
||||
) -> Dict[str, NewEventChainLinks]:
|
||||
# We now calculate chain ID/sequence numbers for any state events we're
|
||||
# persisting. We ignore out of band memberships as we're not in the room
|
||||
# and won't have their auth chain (we'll fix it up later if we join the
|
||||
# room).
|
||||
#
|
||||
# See: docs/auth_chain_difference_algorithm.md
|
||||
|
||||
# We ignore legacy rooms that we aren't filling the chain cover index
|
||||
# for.
|
||||
row = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
table="rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=("room_id", "has_auth_chain_index"),
|
||||
allow_none=True,
|
||||
)
|
||||
if row is None:
|
||||
return {}
|
||||
|
||||
# Filter out already persisted events.
|
||||
rows = self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="events",
|
||||
column="event_id",
|
||||
iterable=[e.event_id for e in state_events],
|
||||
keyvalues={},
|
||||
retcols=("event_id",),
|
||||
)
|
||||
already_persisted_events = {event_id for event_id, in rows}
|
||||
state_events = [
|
||||
event
|
||||
for event in state_events
|
||||
if event.event_id in already_persisted_events
|
||||
]
|
||||
|
||||
if not state_events:
|
||||
return {}
|
||||
|
||||
# We need to know the type/state_key and auth events of the events we're
|
||||
# calculating chain IDs for. We don't rely on having the full Event
|
||||
# instances as we'll potentially be pulling more events from the DB and
|
||||
# we don't need the overhead of fetching/parsing the full event JSON.
|
||||
event_to_types = {e.event_id: (e.type, e.state_key) for e in state_events}
|
||||
event_to_auth_chain = {e.event_id: e.auth_event_ids() for e in state_events}
|
||||
event_to_room_id = {e.event_id: e.room_id for e in state_events}
|
||||
|
||||
return self._calculate_chain_cover_index(
|
||||
txn,
|
||||
self.db_pool,
|
||||
self.store.event_chain_id_gen,
|
||||
event_to_room_id,
|
||||
event_to_types,
|
||||
event_to_auth_chain,
|
||||
)
|
||||
|
||||
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
|
||||
"""Filter the supplied list of event_ids to get those which are prev_events of
|
||||
existing (non-outlier/rejected) events.
|
||||
|
@ -358,6 +457,7 @@ class PersistEventsStore:
|
|||
inhibit_local_membership_updates: bool,
|
||||
state_delta_for_room: Optional[DeltaState],
|
||||
new_forward_extremities: Optional[Set[str]],
|
||||
new_event_links: Dict[str, NewEventChainLinks],
|
||||
) -> None:
|
||||
"""Insert some number of room events into the necessary database tables.
|
||||
|
||||
|
@ -466,7 +566,9 @@ class PersistEventsStore:
|
|||
# Insert into event_to_state_groups.
|
||||
self._store_event_state_mappings_txn(txn, events_and_contexts)
|
||||
|
||||
self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
|
||||
self._persist_event_auth_chain_txn(
|
||||
txn, [e for e, _ in events_and_contexts], new_event_links
|
||||
)
|
||||
|
||||
# _store_rejected_events_txn filters out any events which were
|
||||
# rejected, and returns the filtered list.
|
||||
|
@ -496,6 +598,7 @@ class PersistEventsStore:
|
|||
self,
|
||||
txn: LoggingTransaction,
|
||||
events: List[EventBase],
|
||||
new_event_links: Dict[str, NewEventChainLinks],
|
||||
) -> None:
|
||||
# We only care about state events, so this if there are no state events.
|
||||
if not any(e.is_state() for e in events):
|
||||
|
@ -519,59 +622,8 @@ class PersistEventsStore:
|
|||
],
|
||||
)
|
||||
|
||||
# We now calculate chain ID/sequence numbers for any state events we're
|
||||
# persisting. We ignore out of band memberships as we're not in the room
|
||||
# and won't have their auth chain (we'll fix it up later if we join the
|
||||
# room).
|
||||
#
|
||||
# See: docs/auth_chain_difference_algorithm.md
|
||||
|
||||
# We ignore legacy rooms that we aren't filling the chain cover index
|
||||
# for.
|
||||
rows = cast(
|
||||
List[Tuple[str, Optional[Union[int, bool]]]],
|
||||
self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="rooms",
|
||||
column="room_id",
|
||||
iterable={event.room_id for event in events if event.is_state()},
|
||||
keyvalues={},
|
||||
retcols=("room_id", "has_auth_chain_index"),
|
||||
),
|
||||
)
|
||||
rooms_using_chain_index = {
|
||||
room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
|
||||
}
|
||||
|
||||
state_events = {
|
||||
event.event_id: event
|
||||
for event in events
|
||||
if event.is_state() and event.room_id in rooms_using_chain_index
|
||||
}
|
||||
|
||||
if not state_events:
|
||||
return
|
||||
|
||||
# We need to know the type/state_key and auth events of the events we're
|
||||
# calculating chain IDs for. We don't rely on having the full Event
|
||||
# instances as we'll potentially be pulling more events from the DB and
|
||||
# we don't need the overhead of fetching/parsing the full event JSON.
|
||||
event_to_types = {
|
||||
e.event_id: (e.type, e.state_key) for e in state_events.values()
|
||||
}
|
||||
event_to_auth_chain = {
|
||||
e.event_id: e.auth_event_ids() for e in state_events.values()
|
||||
}
|
||||
event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
|
||||
|
||||
self._add_chain_cover_index(
|
||||
txn,
|
||||
self.db_pool,
|
||||
self.store.event_chain_id_gen,
|
||||
event_to_room_id,
|
||||
event_to_types,
|
||||
event_to_auth_chain,
|
||||
)
|
||||
if new_event_links:
|
||||
self._persist_chain_cover_index(txn, self.db_pool, new_event_links)
|
||||
|
||||
@classmethod
|
||||
def _add_chain_cover_index(
|
||||
|
@ -583,6 +635,35 @@ class PersistEventsStore:
|
|||
event_to_types: Dict[str, Tuple[str, str]],
|
||||
event_to_auth_chain: Dict[str, StrCollection],
|
||||
) -> None:
|
||||
"""Calculate and persist the chain cover index for the given events.
|
||||
|
||||
Args:
|
||||
event_to_room_id: Event ID to the room ID of the event
|
||||
event_to_types: Event ID to type and state_key of the event
|
||||
event_to_auth_chain: Event ID to list of auth event IDs of the
|
||||
event (events with no auth events can be excluded).
|
||||
"""
|
||||
|
||||
new_event_links = cls._calculate_chain_cover_index(
|
||||
txn,
|
||||
db_pool,
|
||||
event_chain_id_gen,
|
||||
event_to_room_id,
|
||||
event_to_types,
|
||||
event_to_auth_chain,
|
||||
)
|
||||
cls._persist_chain_cover_index(txn, db_pool, new_event_links)
|
||||
|
||||
@classmethod
|
||||
def _calculate_chain_cover_index(
|
||||
cls,
|
||||
txn: LoggingTransaction,
|
||||
db_pool: DatabasePool,
|
||||
event_chain_id_gen: SequenceGenerator,
|
||||
event_to_room_id: Dict[str, str],
|
||||
event_to_types: Dict[str, Tuple[str, str]],
|
||||
event_to_auth_chain: Dict[str, StrCollection],
|
||||
) -> Dict[str, NewEventChainLinks]:
|
||||
"""Calculate the chain cover index for the given events.
|
||||
|
||||
Args:
|
||||
|
@ -590,6 +671,10 @@ class PersistEventsStore:
|
|||
event_to_types: Event ID to type and state_key of the event
|
||||
event_to_auth_chain: Event ID to list of auth event IDs of the
|
||||
event (events with no auth events can be excluded).
|
||||
|
||||
Returns:
|
||||
A mapping with any new auth chain links we need to add, keyed by
|
||||
event ID.
|
||||
"""
|
||||
|
||||
# Map from event ID to chain ID/sequence number.
|
||||
|
@ -708,11 +793,11 @@ class PersistEventsStore:
|
|||
room_id = event_to_room_id.get(event_id)
|
||||
if room_id:
|
||||
e_type, state_key = event_to_types[event_id]
|
||||
db_pool.simple_insert_txn(
|
||||
db_pool.simple_upsert_txn(
|
||||
txn,
|
||||
table="event_auth_chain_to_calculate",
|
||||
keyvalues={"event_id": event_id},
|
||||
values={
|
||||
"event_id": event_id,
|
||||
"room_id": room_id,
|
||||
"type": e_type,
|
||||
"state_key": state_key,
|
||||
|
@ -724,7 +809,7 @@ class PersistEventsStore:
|
|||
break
|
||||
|
||||
if not events_to_calc_chain_id_for:
|
||||
return
|
||||
return {}
|
||||
|
||||
# Allocate chain ID/sequence numbers to each new event.
|
||||
new_chain_tuples = cls._allocate_chain_ids(
|
||||
|
@ -739,23 +824,10 @@ class PersistEventsStore:
|
|||
)
|
||||
chain_map.update(new_chain_tuples)
|
||||
|
||||
db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_auth_chains",
|
||||
keys=("event_id", "chain_id", "sequence_number"),
|
||||
values=[
|
||||
(event_id, c_id, seq)
|
||||
for event_id, (c_id, seq) in new_chain_tuples.items()
|
||||
],
|
||||
)
|
||||
|
||||
db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="event_auth_chain_to_calculate",
|
||||
keyvalues={},
|
||||
column="event_id",
|
||||
values=new_chain_tuples,
|
||||
)
|
||||
to_return = {
|
||||
event_id: NewEventChainLinks(chain_id, sequence_number)
|
||||
for event_id, (chain_id, sequence_number) in new_chain_tuples.items()
|
||||
}
|
||||
|
||||
# Now we need to calculate any new links between chains caused by
|
||||
# the new events.
|
||||
|
@ -825,10 +897,38 @@ class PersistEventsStore:
|
|||
auth_chain_id, auth_sequence_number = chain_map[auth_id]
|
||||
|
||||
# Step 2a, add link between the event and auth event
|
||||
to_return[event_id].links.append((auth_chain_id, auth_sequence_number))
|
||||
chain_links.add_link(
|
||||
(chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
|
||||
)
|
||||
|
||||
return to_return
|
||||
|
||||
@classmethod
|
||||
def _persist_chain_cover_index(
|
||||
cls,
|
||||
txn: LoggingTransaction,
|
||||
db_pool: DatabasePool,
|
||||
new_event_links: Dict[str, NewEventChainLinks],
|
||||
) -> None:
|
||||
db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_auth_chains",
|
||||
keys=("event_id", "chain_id", "sequence_number"),
|
||||
values=[
|
||||
(event_id, new_links.chain_id, new_links.sequence_number)
|
||||
for event_id, new_links in new_event_links.items()
|
||||
],
|
||||
)
|
||||
|
||||
db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="event_auth_chain_to_calculate",
|
||||
keyvalues={},
|
||||
column="event_id",
|
||||
values=new_event_links,
|
||||
)
|
||||
|
||||
db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_auth_chain_links",
|
||||
|
@ -838,7 +938,16 @@ class PersistEventsStore:
|
|||
"target_chain_id",
|
||||
"target_sequence_number",
|
||||
),
|
||||
values=list(chain_links.get_additions()),
|
||||
values=[
|
||||
(
|
||||
new_links.chain_id,
|
||||
new_links.sequence_number,
|
||||
target_chain_id,
|
||||
target_sequence_number,
|
||||
)
|
||||
for new_links in new_event_links.values()
|
||||
for (target_chain_id, target_sequence_number) in new_links.links
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -688,7 +688,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
|
|||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/notifications?from=",
|
||||
"/notifications",
|
||||
access_token=tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
from typing import List, Optional, Tuple
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
@ -48,6 +49,14 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
self.sync_handler = homeserver.get_sync_handler()
|
||||
self.auth_handler = homeserver.get_auth_handler()
|
||||
|
||||
self.user_id = self.register_user("user", "pass")
|
||||
self.access_token = self.login("user", "pass")
|
||||
self.other_user_id = self.register_user("otheruser", "pass")
|
||||
self.other_access_token = self.login("otheruser", "pass")
|
||||
|
||||
# Create a room
|
||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
# Mock out the calls over federation.
|
||||
fed_transport_client = Mock(spec=["send_transaction"])
|
||||
|
@ -61,32 +70,22 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
"""
|
||||
Local users will get notified for invites
|
||||
"""
|
||||
|
||||
user_id = self.register_user("user", "pass")
|
||||
access_token = self.login("user", "pass")
|
||||
other_user_id = self.register_user("otheruser", "pass")
|
||||
other_access_token = self.login("otheruser", "pass")
|
||||
|
||||
# Create a room
|
||||
room = self.helper.create_room_as(user_id, tok=access_token)
|
||||
|
||||
# Check we start with no pushes
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/notifications",
|
||||
access_token=other_access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual(len(channel.json_body["notifications"]), 0, channel.json_body)
|
||||
self._request_notifications(from_token=None, limit=1, expected_count=0)
|
||||
|
||||
# Send an invite
|
||||
self.helper.invite(room=room, src=user_id, targ=other_user_id, tok=access_token)
|
||||
self.helper.invite(
|
||||
room=self.room_id,
|
||||
src=self.user_id,
|
||||
targ=self.other_user_id,
|
||||
tok=self.access_token,
|
||||
)
|
||||
|
||||
# We should have a notification now
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/notifications",
|
||||
access_token=other_access_token,
|
||||
access_token=self.other_access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body)
|
||||
|
@ -95,3 +94,139 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
"invite",
|
||||
channel.json_body,
|
||||
)
|
||||
|
||||
def test_pagination_of_notifications(self) -> None:
|
||||
"""
|
||||
Check that pagination of notifications works.
|
||||
"""
|
||||
# Check we start with no pushes
|
||||
self._request_notifications(from_token=None, limit=1, expected_count=0)
|
||||
|
||||
# Send an invite and have the other user join the room.
|
||||
self.helper.invite(
|
||||
room=self.room_id,
|
||||
src=self.user_id,
|
||||
targ=self.other_user_id,
|
||||
tok=self.access_token,
|
||||
)
|
||||
self.helper.join(self.room_id, self.other_user_id, tok=self.other_access_token)
|
||||
|
||||
# Send 5 messages in the room and note down their event IDs.
|
||||
sent_event_ids = []
|
||||
for _ in range(5):
|
||||
resp = self.helper.send_event(
|
||||
self.room_id,
|
||||
"m.room.message",
|
||||
{"body": "honk", "msgtype": "m.text"},
|
||||
tok=self.access_token,
|
||||
)
|
||||
sent_event_ids.append(resp["event_id"])
|
||||
|
||||
# We expect to get notifications for messages in reverse order.
|
||||
# So reverse this list of event IDs to make it easier to compare
|
||||
# against later.
|
||||
sent_event_ids.reverse()
|
||||
|
||||
# We should have a few notifications now. Let's try and fetch the first 2.
|
||||
notification_event_ids, _ = self._request_notifications(
|
||||
from_token=None, limit=2, expected_count=2
|
||||
)
|
||||
|
||||
# Check we got the expected event IDs back.
|
||||
self.assertEqual(notification_event_ids, sent_event_ids[:2])
|
||||
|
||||
# Try requesting again without a 'from' query parameter. We should get the
|
||||
# same two notifications back.
|
||||
notification_event_ids, next_token = self._request_notifications(
|
||||
from_token=None, limit=2, expected_count=2
|
||||
)
|
||||
self.assertEqual(notification_event_ids, sent_event_ids[:2])
|
||||
|
||||
# Ask for the next 5 notifications, though there should only be
|
||||
# 4 remaining; the next 3 messages and the invite.
|
||||
#
|
||||
# We need to use the "next_token" from the response as the "from"
|
||||
# query parameter in the next request in order to paginate.
|
||||
notification_event_ids, next_token = self._request_notifications(
|
||||
from_token=next_token, limit=5, expected_count=4
|
||||
)
|
||||
# Ensure we chop off the invite on the end.
|
||||
notification_event_ids = notification_event_ids[:-1]
|
||||
self.assertEqual(notification_event_ids, sent_event_ids[2:])
|
||||
|
||||
def _request_notifications(
|
||||
self, from_token: Optional[str], limit: int, expected_count: int
|
||||
) -> Tuple[List[str], str]:
|
||||
"""
|
||||
Make a request to /notifications to get the latest events to be notified about.
|
||||
|
||||
Only the event IDs are returned. The request is made by the "other user".
|
||||
|
||||
Args:
|
||||
from_token: An optional starting parameter.
|
||||
limit: The maximum number of results to return.
|
||||
expected_count: The number of events to expect in the response.
|
||||
|
||||
Returns:
|
||||
A list of event IDs that the client should be notified about.
|
||||
Events are returned newest-first.
|
||||
"""
|
||||
# Construct the request path.
|
||||
path = f"/notifications?limit={limit}"
|
||||
if from_token is not None:
|
||||
path += f"&from={from_token}"
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
path,
|
||||
access_token=self.other_access_token,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(
|
||||
len(channel.json_body["notifications"]), expected_count, channel.json_body
|
||||
)
|
||||
|
||||
# Extract the necessary data from the response.
|
||||
next_token = channel.json_body["next_token"]
|
||||
event_ids = [
|
||||
event["event"]["event_id"] for event in channel.json_body["notifications"]
|
||||
]
|
||||
|
||||
return event_ids, next_token
|
||||
|
||||
def test_parameters(self) -> None:
|
||||
"""
|
||||
Test that appropriate errors are returned when query parameters are malformed.
|
||||
"""
|
||||
# Test that no parameters are required.
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/notifications",
|
||||
access_token=self.other_access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
# Test that limit cannot be negative
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/notifications?limit=-1",
|
||||
access_token=self.other_access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, 400)
|
||||
|
||||
# Test that the 'limit' parameter must be an integer.
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/notifications?limit=foobar",
|
||||
access_token=self.other_access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, 400)
|
||||
|
||||
# Test that the 'from' parameter must be an integer.
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/notifications?from=osborne",
|
||||
access_token=self.other_access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, 400)
|
||||
|
|
|
@ -447,7 +447,14 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
|||
)
|
||||
|
||||
# Actually call the function that calculates the auth chain stuff.
|
||||
persist_events_store._persist_event_auth_chain_txn(txn, events)
|
||||
new_event_links = (
|
||||
persist_events_store.calculate_chain_cover_index_for_events_txn(
|
||||
txn, events[0].room_id, [e for e in events if e.is_state()]
|
||||
)
|
||||
)
|
||||
persist_events_store._persist_event_auth_chain_txn(
|
||||
txn, events, new_event_links
|
||||
)
|
||||
|
||||
self.get_success(
|
||||
persist_events_store.db_pool.runInteraction(
|
||||
|
|
|
@ -365,12 +365,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
},
|
||||
)
|
||||
|
||||
events = [
|
||||
cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
|
||||
for event_id in AUTH_GRAPH
|
||||
]
|
||||
new_event_links = (
|
||||
self.persist_events.calculate_chain_cover_index_for_events_txn(
|
||||
txn, room_id, [e for e in events if e.is_state()]
|
||||
)
|
||||
)
|
||||
self.persist_events._persist_event_auth_chain_txn(
|
||||
txn,
|
||||
[
|
||||
cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
|
||||
for event_id in AUTH_GRAPH
|
||||
],
|
||||
events,
|
||||
new_event_links,
|
||||
)
|
||||
|
||||
self.get_success(
|
||||
|
@ -628,13 +635,20 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
# Insert all events apart from 'B'
|
||||
events = [
|
||||
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
|
||||
for event_id in auth_graph
|
||||
if event_id != "b"
|
||||
]
|
||||
new_event_links = (
|
||||
self.persist_events.calculate_chain_cover_index_for_events_txn(
|
||||
txn, room_id, [e for e in events if e.is_state()]
|
||||
)
|
||||
)
|
||||
self.persist_events._persist_event_auth_chain_txn(
|
||||
txn,
|
||||
[
|
||||
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
|
||||
for event_id in auth_graph
|
||||
if event_id != "b"
|
||||
],
|
||||
events,
|
||||
new_event_links,
|
||||
)
|
||||
|
||||
# Now we insert the event 'B' without a chain cover, by temporarily
|
||||
|
@ -647,9 +661,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
|||
updatevalues={"has_auth_chain_index": False},
|
||||
)
|
||||
|
||||
events = [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))]
|
||||
new_event_links = (
|
||||
self.persist_events.calculate_chain_cover_index_for_events_txn(
|
||||
txn, room_id, [e for e in events if e.is_state()]
|
||||
)
|
||||
)
|
||||
self.persist_events._persist_event_auth_chain_txn(
|
||||
txn,
|
||||
[cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
|
||||
txn, events, new_event_links
|
||||
)
|
||||
|
||||
self.store.db_pool.simple_update_txn(
|
||||
|
|
Loading…
Reference in a new issue