mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-20 02:24:54 +03:00
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
This commit is contained in:
commit
4f308ea362
14 changed files with 438 additions and 116 deletions
1
changelog.d/17283.bugfix
Normal file
1
changelog.d/17283.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix a long-standing bug where an invalid 'from' parameter to [`/notifications`](https://spec.matrix.org/v1.10/client-server-api/#get_matrixclientv3notifications) would result in an Internal Server Error.
|
1
changelog.d/17291.misc
Normal file
1
changelog.d/17291.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Do not block event sending/receiving while calulating large event auth chains.
|
2
changelog.d/17294.feature
Normal file
2
changelog.d/17294.feature
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
`register_new_matrix_user` now supports a --password-file flag, which
|
||||||
|
is useful for scripting.
|
6
debian/changelog
vendored
6
debian/changelog
vendored
|
@ -1,3 +1,9 @@
|
||||||
|
matrix-synapse-py3 (1.109.0+nmu1) UNRELEASED; urgency=medium
|
||||||
|
|
||||||
|
* `register_new_matrix_user` now supports a --password-file flag.
|
||||||
|
|
||||||
|
-- Synapse Packaging team <packages@matrix.org> Tue, 18 Jun 2024 13:29:36 +0100
|
||||||
|
|
||||||
matrix-synapse-py3 (1.109.0) stable; urgency=medium
|
matrix-synapse-py3 (1.109.0) stable; urgency=medium
|
||||||
|
|
||||||
* New synapse release 1.109.0.
|
* New synapse release 1.109.0.
|
||||||
|
|
8
debian/register_new_matrix_user.ronn
vendored
8
debian/register_new_matrix_user.ronn
vendored
|
@ -31,8 +31,12 @@ A sample YAML file accepted by `register_new_matrix_user` is described below:
|
||||||
Local part of the new user. Will prompt if omitted.
|
Local part of the new user. Will prompt if omitted.
|
||||||
|
|
||||||
* `-p`, `--password`:
|
* `-p`, `--password`:
|
||||||
New password for user. Will prompt if omitted. Supplying the password
|
New password for user. Will prompt if this option and `--password-file` are omitted.
|
||||||
on the command line is not recommended. Use the STDIN instead.
|
Supplying the password on the command line is not recommended.
|
||||||
|
|
||||||
|
* `--password-file`:
|
||||||
|
File containing the new password for user. If set, overrides `--password`.
|
||||||
|
This is a more secure alternative to specifying the password on the command line.
|
||||||
|
|
||||||
* `-a`, `--admin`:
|
* `-a`, `--admin`:
|
||||||
Register new user as an admin. Will prompt if omitted.
|
Register new user as an admin. Will prompt if omitted.
|
||||||
|
|
|
@ -173,11 +173,18 @@ def main() -> None:
|
||||||
default=None,
|
default=None,
|
||||||
help="Local part of the new user. Will prompt if omitted.",
|
help="Local part of the new user. Will prompt if omitted.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
password_group = parser.add_mutually_exclusive_group()
|
||||||
|
password_group.add_argument(
|
||||||
"-p",
|
"-p",
|
||||||
"--password",
|
"--password",
|
||||||
default=None,
|
default=None,
|
||||||
help="New password for user. Will prompt if omitted.",
|
help="New password for user. Will prompt for a password if "
|
||||||
|
"this flag and `--password-file` are both omitted.",
|
||||||
|
)
|
||||||
|
password_group.add_argument(
|
||||||
|
"--password-file",
|
||||||
|
default=None,
|
||||||
|
help="File containing the new password for user. If set, will override `--password`.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-t",
|
"-t",
|
||||||
|
@ -247,6 +254,11 @@ def main() -> None:
|
||||||
print(_NO_SHARED_SECRET_OPTS_ERROR, file=sys.stderr)
|
print(_NO_SHARED_SECRET_OPTS_ERROR, file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
if args.password_file:
|
||||||
|
password = _read_file(args.password_file, "password-file").strip()
|
||||||
|
else:
|
||||||
|
password = args.password
|
||||||
|
|
||||||
if args.server_url:
|
if args.server_url:
|
||||||
server_url = args.server_url
|
server_url = args.server_url
|
||||||
elif config is not None:
|
elif config is not None:
|
||||||
|
@ -269,9 +281,7 @@ def main() -> None:
|
||||||
if args.admin or args.no_admin:
|
if args.admin or args.no_admin:
|
||||||
admin = args.admin
|
admin = args.admin
|
||||||
|
|
||||||
register_new_user(
|
register_new_user(args.user, password, server_url, secret, admin, args.user_type)
|
||||||
args.user, args.password, server_url, secret, admin, args.user_type
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _read_file(file_path: Any, config_path: str) -> str:
|
def _read_file(file_path: Any, config_path: str) -> str:
|
||||||
|
|
|
@ -32,6 +32,7 @@ from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
from ...api.errors import SynapseError
|
||||||
from ._base import client_patterns
|
from ._base import client_patterns
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -56,7 +57,22 @@ class NotificationsServlet(RestServlet):
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
from_token = parse_string(request, "from", required=False)
|
# While this is intended to be "string" to clients, the 'from' token
|
||||||
|
# is actually based on a numeric ID. So it must parse to an int.
|
||||||
|
from_token_str = parse_string(request, "from", required=False)
|
||||||
|
if from_token_str is not None:
|
||||||
|
# Parse to an integer.
|
||||||
|
try:
|
||||||
|
from_token = int(from_token_str)
|
||||||
|
except ValueError:
|
||||||
|
# If it doesn't parse to an integer, then this cannot possibly be a valid
|
||||||
|
# pagination token, as we only hand out integers.
|
||||||
|
raise SynapseError(
|
||||||
|
400, 'Query parameter "from" contains unrecognised token'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from_token = None
|
||||||
|
|
||||||
limit = parse_integer(request, "limit", default=50)
|
limit = parse_integer(request, "limit", default=50)
|
||||||
only = parse_string(request, "only", required=False)
|
only = parse_string(request, "only", required=False)
|
||||||
|
|
||||||
|
|
|
@ -617,6 +617,17 @@ class EventsPersistenceStorageController:
|
||||||
room_id, chunk
|
room_id, chunk
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with Measure(self._clock, "calculate_chain_cover_index_for_events"):
|
||||||
|
# We now calculate chain ID/sequence numbers for any state events we're
|
||||||
|
# persisting. We ignore out of band memberships as we're not in the room
|
||||||
|
# and won't have their auth chain (we'll fix it up later if we join the
|
||||||
|
# room).
|
||||||
|
#
|
||||||
|
# See: docs/auth_chain_difference_algorithm.md
|
||||||
|
new_event_links = await self.persist_events_store.calculate_chain_cover_index_for_events(
|
||||||
|
room_id, [e for e, _ in chunk]
|
||||||
|
)
|
||||||
|
|
||||||
await self.persist_events_store._persist_events_and_state_updates(
|
await self.persist_events_store._persist_events_and_state_updates(
|
||||||
room_id,
|
room_id,
|
||||||
chunk,
|
chunk,
|
||||||
|
@ -624,6 +635,7 @@ class EventsPersistenceStorageController:
|
||||||
new_forward_extremities=new_forward_extremities,
|
new_forward_extremities=new_forward_extremities,
|
||||||
use_negative_stream_ordering=backfilled,
|
use_negative_stream_ordering=backfilled,
|
||||||
inhibit_local_membership_updates=backfilled,
|
inhibit_local_membership_updates=backfilled,
|
||||||
|
new_event_links=new_event_links,
|
||||||
)
|
)
|
||||||
|
|
||||||
return replaced_events
|
return replaced_events
|
||||||
|
|
|
@ -1829,7 +1829,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
||||||
async def get_push_actions_for_user(
|
async def get_push_actions_for_user(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
before: Optional[str] = None,
|
before: Optional[int] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
only_highlight: bool = False,
|
only_highlight: bool = False,
|
||||||
) -> List[UserPushAction]:
|
) -> List[UserPushAction]:
|
||||||
|
|
|
@ -34,7 +34,6 @@ from typing import (
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -100,6 +99,23 @@ class DeltaState:
|
||||||
return not self.to_delete and not self.to_insert and not self.no_longer_in_room
|
return not self.to_delete and not self.to_insert and not self.no_longer_in_room
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, auto_attribs=True)
|
||||||
|
class NewEventChainLinks:
|
||||||
|
"""Information about new auth chain links that need to be added to the DB.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
chain_id, sequence_number: the IDs corresponding to the event being
|
||||||
|
inserted, and the starting point of the links
|
||||||
|
links: Lists the links that need to be added, 2-tuple of the chain
|
||||||
|
ID/sequence number of the end point of the link.
|
||||||
|
"""
|
||||||
|
|
||||||
|
chain_id: int
|
||||||
|
sequence_number: int
|
||||||
|
|
||||||
|
links: List[Tuple[int, int]] = attr.Factory(list)
|
||||||
|
|
||||||
|
|
||||||
class PersistEventsStore:
|
class PersistEventsStore:
|
||||||
"""Contains all the functions for writing events to the database.
|
"""Contains all the functions for writing events to the database.
|
||||||
|
|
||||||
|
@ -148,6 +164,7 @@ class PersistEventsStore:
|
||||||
*,
|
*,
|
||||||
state_delta_for_room: Optional[DeltaState],
|
state_delta_for_room: Optional[DeltaState],
|
||||||
new_forward_extremities: Optional[Set[str]],
|
new_forward_extremities: Optional[Set[str]],
|
||||||
|
new_event_links: Dict[str, NewEventChainLinks],
|
||||||
use_negative_stream_ordering: bool = False,
|
use_negative_stream_ordering: bool = False,
|
||||||
inhibit_local_membership_updates: bool = False,
|
inhibit_local_membership_updates: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -217,6 +234,7 @@ class PersistEventsStore:
|
||||||
inhibit_local_membership_updates=inhibit_local_membership_updates,
|
inhibit_local_membership_updates=inhibit_local_membership_updates,
|
||||||
state_delta_for_room=state_delta_for_room,
|
state_delta_for_room=state_delta_for_room,
|
||||||
new_forward_extremities=new_forward_extremities,
|
new_forward_extremities=new_forward_extremities,
|
||||||
|
new_event_links=new_event_links,
|
||||||
)
|
)
|
||||||
persist_event_counter.inc(len(events_and_contexts))
|
persist_event_counter.inc(len(events_and_contexts))
|
||||||
|
|
||||||
|
@ -243,6 +261,87 @@ class PersistEventsStore:
|
||||||
(room_id,), frozenset(new_forward_extremities)
|
(room_id,), frozenset(new_forward_extremities)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def calculate_chain_cover_index_for_events(
|
||||||
|
self, room_id: str, events: Collection[EventBase]
|
||||||
|
) -> Dict[str, NewEventChainLinks]:
|
||||||
|
# Filter to state events, and ensure there are no duplicates.
|
||||||
|
state_events = []
|
||||||
|
seen_events = set()
|
||||||
|
for event in events:
|
||||||
|
if not event.is_state() or event.event_id in seen_events:
|
||||||
|
continue
|
||||||
|
|
||||||
|
state_events.append(event)
|
||||||
|
seen_events.add(event.event_id)
|
||||||
|
|
||||||
|
if not state_events:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"_calculate_chain_cover_index_for_events",
|
||||||
|
self.calculate_chain_cover_index_for_events_txn,
|
||||||
|
room_id,
|
||||||
|
state_events,
|
||||||
|
)
|
||||||
|
|
||||||
|
def calculate_chain_cover_index_for_events_txn(
|
||||||
|
self, txn: LoggingTransaction, room_id: str, state_events: Collection[EventBase]
|
||||||
|
) -> Dict[str, NewEventChainLinks]:
|
||||||
|
# We now calculate chain ID/sequence numbers for any state events we're
|
||||||
|
# persisting. We ignore out of band memberships as we're not in the room
|
||||||
|
# and won't have their auth chain (we'll fix it up later if we join the
|
||||||
|
# room).
|
||||||
|
#
|
||||||
|
# See: docs/auth_chain_difference_algorithm.md
|
||||||
|
|
||||||
|
# We ignore legacy rooms that we aren't filling the chain cover index
|
||||||
|
# for.
|
||||||
|
row = self.db_pool.simple_select_one_txn(
|
||||||
|
txn,
|
||||||
|
table="rooms",
|
||||||
|
keyvalues={"room_id": room_id},
|
||||||
|
retcols=("room_id", "has_auth_chain_index"),
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
if row is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Filter out already persisted events.
|
||||||
|
rows = self.db_pool.simple_select_many_txn(
|
||||||
|
txn,
|
||||||
|
table="events",
|
||||||
|
column="event_id",
|
||||||
|
iterable=[e.event_id for e in state_events],
|
||||||
|
keyvalues={},
|
||||||
|
retcols=("event_id",),
|
||||||
|
)
|
||||||
|
already_persisted_events = {event_id for event_id, in rows}
|
||||||
|
state_events = [
|
||||||
|
event
|
||||||
|
for event in state_events
|
||||||
|
if event.event_id in already_persisted_events
|
||||||
|
]
|
||||||
|
|
||||||
|
if not state_events:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# We need to know the type/state_key and auth events of the events we're
|
||||||
|
# calculating chain IDs for. We don't rely on having the full Event
|
||||||
|
# instances as we'll potentially be pulling more events from the DB and
|
||||||
|
# we don't need the overhead of fetching/parsing the full event JSON.
|
||||||
|
event_to_types = {e.event_id: (e.type, e.state_key) for e in state_events}
|
||||||
|
event_to_auth_chain = {e.event_id: e.auth_event_ids() for e in state_events}
|
||||||
|
event_to_room_id = {e.event_id: e.room_id for e in state_events}
|
||||||
|
|
||||||
|
return self._calculate_chain_cover_index(
|
||||||
|
txn,
|
||||||
|
self.db_pool,
|
||||||
|
self.store.event_chain_id_gen,
|
||||||
|
event_to_room_id,
|
||||||
|
event_to_types,
|
||||||
|
event_to_auth_chain,
|
||||||
|
)
|
||||||
|
|
||||||
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
|
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
|
||||||
"""Filter the supplied list of event_ids to get those which are prev_events of
|
"""Filter the supplied list of event_ids to get those which are prev_events of
|
||||||
existing (non-outlier/rejected) events.
|
existing (non-outlier/rejected) events.
|
||||||
|
@ -358,6 +457,7 @@ class PersistEventsStore:
|
||||||
inhibit_local_membership_updates: bool,
|
inhibit_local_membership_updates: bool,
|
||||||
state_delta_for_room: Optional[DeltaState],
|
state_delta_for_room: Optional[DeltaState],
|
||||||
new_forward_extremities: Optional[Set[str]],
|
new_forward_extremities: Optional[Set[str]],
|
||||||
|
new_event_links: Dict[str, NewEventChainLinks],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Insert some number of room events into the necessary database tables.
|
"""Insert some number of room events into the necessary database tables.
|
||||||
|
|
||||||
|
@ -466,7 +566,9 @@ class PersistEventsStore:
|
||||||
# Insert into event_to_state_groups.
|
# Insert into event_to_state_groups.
|
||||||
self._store_event_state_mappings_txn(txn, events_and_contexts)
|
self._store_event_state_mappings_txn(txn, events_and_contexts)
|
||||||
|
|
||||||
self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
|
self._persist_event_auth_chain_txn(
|
||||||
|
txn, [e for e, _ in events_and_contexts], new_event_links
|
||||||
|
)
|
||||||
|
|
||||||
# _store_rejected_events_txn filters out any events which were
|
# _store_rejected_events_txn filters out any events which were
|
||||||
# rejected, and returns the filtered list.
|
# rejected, and returns the filtered list.
|
||||||
|
@ -496,6 +598,7 @@ class PersistEventsStore:
|
||||||
self,
|
self,
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
events: List[EventBase],
|
events: List[EventBase],
|
||||||
|
new_event_links: Dict[str, NewEventChainLinks],
|
||||||
) -> None:
|
) -> None:
|
||||||
# We only care about state events, so this if there are no state events.
|
# We only care about state events, so this if there are no state events.
|
||||||
if not any(e.is_state() for e in events):
|
if not any(e.is_state() for e in events):
|
||||||
|
@ -519,59 +622,8 @@ class PersistEventsStore:
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# We now calculate chain ID/sequence numbers for any state events we're
|
if new_event_links:
|
||||||
# persisting. We ignore out of band memberships as we're not in the room
|
self._persist_chain_cover_index(txn, self.db_pool, new_event_links)
|
||||||
# and won't have their auth chain (we'll fix it up later if we join the
|
|
||||||
# room).
|
|
||||||
#
|
|
||||||
# See: docs/auth_chain_difference_algorithm.md
|
|
||||||
|
|
||||||
# We ignore legacy rooms that we aren't filling the chain cover index
|
|
||||||
# for.
|
|
||||||
rows = cast(
|
|
||||||
List[Tuple[str, Optional[Union[int, bool]]]],
|
|
||||||
self.db_pool.simple_select_many_txn(
|
|
||||||
txn,
|
|
||||||
table="rooms",
|
|
||||||
column="room_id",
|
|
||||||
iterable={event.room_id for event in events if event.is_state()},
|
|
||||||
keyvalues={},
|
|
||||||
retcols=("room_id", "has_auth_chain_index"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
rooms_using_chain_index = {
|
|
||||||
room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
|
|
||||||
}
|
|
||||||
|
|
||||||
state_events = {
|
|
||||||
event.event_id: event
|
|
||||||
for event in events
|
|
||||||
if event.is_state() and event.room_id in rooms_using_chain_index
|
|
||||||
}
|
|
||||||
|
|
||||||
if not state_events:
|
|
||||||
return
|
|
||||||
|
|
||||||
# We need to know the type/state_key and auth events of the events we're
|
|
||||||
# calculating chain IDs for. We don't rely on having the full Event
|
|
||||||
# instances as we'll potentially be pulling more events from the DB and
|
|
||||||
# we don't need the overhead of fetching/parsing the full event JSON.
|
|
||||||
event_to_types = {
|
|
||||||
e.event_id: (e.type, e.state_key) for e in state_events.values()
|
|
||||||
}
|
|
||||||
event_to_auth_chain = {
|
|
||||||
e.event_id: e.auth_event_ids() for e in state_events.values()
|
|
||||||
}
|
|
||||||
event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
|
|
||||||
|
|
||||||
self._add_chain_cover_index(
|
|
||||||
txn,
|
|
||||||
self.db_pool,
|
|
||||||
self.store.event_chain_id_gen,
|
|
||||||
event_to_room_id,
|
|
||||||
event_to_types,
|
|
||||||
event_to_auth_chain,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _add_chain_cover_index(
|
def _add_chain_cover_index(
|
||||||
|
@ -583,6 +635,35 @@ class PersistEventsStore:
|
||||||
event_to_types: Dict[str, Tuple[str, str]],
|
event_to_types: Dict[str, Tuple[str, str]],
|
||||||
event_to_auth_chain: Dict[str, StrCollection],
|
event_to_auth_chain: Dict[str, StrCollection],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Calculate and persist the chain cover index for the given events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_to_room_id: Event ID to the room ID of the event
|
||||||
|
event_to_types: Event ID to type and state_key of the event
|
||||||
|
event_to_auth_chain: Event ID to list of auth event IDs of the
|
||||||
|
event (events with no auth events can be excluded).
|
||||||
|
"""
|
||||||
|
|
||||||
|
new_event_links = cls._calculate_chain_cover_index(
|
||||||
|
txn,
|
||||||
|
db_pool,
|
||||||
|
event_chain_id_gen,
|
||||||
|
event_to_room_id,
|
||||||
|
event_to_types,
|
||||||
|
event_to_auth_chain,
|
||||||
|
)
|
||||||
|
cls._persist_chain_cover_index(txn, db_pool, new_event_links)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _calculate_chain_cover_index(
|
||||||
|
cls,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
db_pool: DatabasePool,
|
||||||
|
event_chain_id_gen: SequenceGenerator,
|
||||||
|
event_to_room_id: Dict[str, str],
|
||||||
|
event_to_types: Dict[str, Tuple[str, str]],
|
||||||
|
event_to_auth_chain: Dict[str, StrCollection],
|
||||||
|
) -> Dict[str, NewEventChainLinks]:
|
||||||
"""Calculate the chain cover index for the given events.
|
"""Calculate the chain cover index for the given events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -590,6 +671,10 @@ class PersistEventsStore:
|
||||||
event_to_types: Event ID to type and state_key of the event
|
event_to_types: Event ID to type and state_key of the event
|
||||||
event_to_auth_chain: Event ID to list of auth event IDs of the
|
event_to_auth_chain: Event ID to list of auth event IDs of the
|
||||||
event (events with no auth events can be excluded).
|
event (events with no auth events can be excluded).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A mapping with any new auth chain links we need to add, keyed by
|
||||||
|
event ID.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Map from event ID to chain ID/sequence number.
|
# Map from event ID to chain ID/sequence number.
|
||||||
|
@ -708,11 +793,11 @@ class PersistEventsStore:
|
||||||
room_id = event_to_room_id.get(event_id)
|
room_id = event_to_room_id.get(event_id)
|
||||||
if room_id:
|
if room_id:
|
||||||
e_type, state_key = event_to_types[event_id]
|
e_type, state_key = event_to_types[event_id]
|
||||||
db_pool.simple_insert_txn(
|
db_pool.simple_upsert_txn(
|
||||||
txn,
|
txn,
|
||||||
table="event_auth_chain_to_calculate",
|
table="event_auth_chain_to_calculate",
|
||||||
|
keyvalues={"event_id": event_id},
|
||||||
values={
|
values={
|
||||||
"event_id": event_id,
|
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
"type": e_type,
|
"type": e_type,
|
||||||
"state_key": state_key,
|
"state_key": state_key,
|
||||||
|
@ -724,7 +809,7 @@ class PersistEventsStore:
|
||||||
break
|
break
|
||||||
|
|
||||||
if not events_to_calc_chain_id_for:
|
if not events_to_calc_chain_id_for:
|
||||||
return
|
return {}
|
||||||
|
|
||||||
# Allocate chain ID/sequence numbers to each new event.
|
# Allocate chain ID/sequence numbers to each new event.
|
||||||
new_chain_tuples = cls._allocate_chain_ids(
|
new_chain_tuples = cls._allocate_chain_ids(
|
||||||
|
@ -739,23 +824,10 @@ class PersistEventsStore:
|
||||||
)
|
)
|
||||||
chain_map.update(new_chain_tuples)
|
chain_map.update(new_chain_tuples)
|
||||||
|
|
||||||
db_pool.simple_insert_many_txn(
|
to_return = {
|
||||||
txn,
|
event_id: NewEventChainLinks(chain_id, sequence_number)
|
||||||
table="event_auth_chains",
|
for event_id, (chain_id, sequence_number) in new_chain_tuples.items()
|
||||||
keys=("event_id", "chain_id", "sequence_number"),
|
}
|
||||||
values=[
|
|
||||||
(event_id, c_id, seq)
|
|
||||||
for event_id, (c_id, seq) in new_chain_tuples.items()
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
db_pool.simple_delete_many_txn(
|
|
||||||
txn,
|
|
||||||
table="event_auth_chain_to_calculate",
|
|
||||||
keyvalues={},
|
|
||||||
column="event_id",
|
|
||||||
values=new_chain_tuples,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Now we need to calculate any new links between chains caused by
|
# Now we need to calculate any new links between chains caused by
|
||||||
# the new events.
|
# the new events.
|
||||||
|
@ -825,10 +897,38 @@ class PersistEventsStore:
|
||||||
auth_chain_id, auth_sequence_number = chain_map[auth_id]
|
auth_chain_id, auth_sequence_number = chain_map[auth_id]
|
||||||
|
|
||||||
# Step 2a, add link between the event and auth event
|
# Step 2a, add link between the event and auth event
|
||||||
|
to_return[event_id].links.append((auth_chain_id, auth_sequence_number))
|
||||||
chain_links.add_link(
|
chain_links.add_link(
|
||||||
(chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
|
(chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return to_return
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _persist_chain_cover_index(
|
||||||
|
cls,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
db_pool: DatabasePool,
|
||||||
|
new_event_links: Dict[str, NewEventChainLinks],
|
||||||
|
) -> None:
|
||||||
|
db_pool.simple_insert_many_txn(
|
||||||
|
txn,
|
||||||
|
table="event_auth_chains",
|
||||||
|
keys=("event_id", "chain_id", "sequence_number"),
|
||||||
|
values=[
|
||||||
|
(event_id, new_links.chain_id, new_links.sequence_number)
|
||||||
|
for event_id, new_links in new_event_links.items()
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
db_pool.simple_delete_many_txn(
|
||||||
|
txn,
|
||||||
|
table="event_auth_chain_to_calculate",
|
||||||
|
keyvalues={},
|
||||||
|
column="event_id",
|
||||||
|
values=new_event_links,
|
||||||
|
)
|
||||||
|
|
||||||
db_pool.simple_insert_many_txn(
|
db_pool.simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="event_auth_chain_links",
|
table="event_auth_chain_links",
|
||||||
|
@ -838,7 +938,16 @@ class PersistEventsStore:
|
||||||
"target_chain_id",
|
"target_chain_id",
|
||||||
"target_sequence_number",
|
"target_sequence_number",
|
||||||
),
|
),
|
||||||
values=list(chain_links.get_additions()),
|
values=[
|
||||||
|
(
|
||||||
|
new_links.chain_id,
|
||||||
|
new_links.sequence_number,
|
||||||
|
target_chain_id,
|
||||||
|
target_sequence_number,
|
||||||
|
)
|
||||||
|
for new_links in new_event_links.values()
|
||||||
|
for (target_chain_id, target_sequence_number) in new_links.links
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -688,7 +688,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
"/notifications?from=",
|
"/notifications",
|
||||||
access_token=tok,
|
access_token=tok,
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
# [This file includes modifications made by New Vector Limited]
|
# [This file includes modifications made by New Vector Limited]
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
from unittest.mock import AsyncMock, Mock
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
@ -48,6 +49,14 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
self.sync_handler = homeserver.get_sync_handler()
|
self.sync_handler = homeserver.get_sync_handler()
|
||||||
self.auth_handler = homeserver.get_auth_handler()
|
self.auth_handler = homeserver.get_auth_handler()
|
||||||
|
|
||||||
|
self.user_id = self.register_user("user", "pass")
|
||||||
|
self.access_token = self.login("user", "pass")
|
||||||
|
self.other_user_id = self.register_user("otheruser", "pass")
|
||||||
|
self.other_access_token = self.login("otheruser", "pass")
|
||||||
|
|
||||||
|
# Create a room
|
||||||
|
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
|
||||||
|
|
||||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
# Mock out the calls over federation.
|
# Mock out the calls over federation.
|
||||||
fed_transport_client = Mock(spec=["send_transaction"])
|
fed_transport_client = Mock(spec=["send_transaction"])
|
||||||
|
@ -61,32 +70,22 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
Local users will get notified for invites
|
Local users will get notified for invites
|
||||||
"""
|
"""
|
||||||
|
|
||||||
user_id = self.register_user("user", "pass")
|
|
||||||
access_token = self.login("user", "pass")
|
|
||||||
other_user_id = self.register_user("otheruser", "pass")
|
|
||||||
other_access_token = self.login("otheruser", "pass")
|
|
||||||
|
|
||||||
# Create a room
|
|
||||||
room = self.helper.create_room_as(user_id, tok=access_token)
|
|
||||||
|
|
||||||
# Check we start with no pushes
|
# Check we start with no pushes
|
||||||
channel = self.make_request(
|
self._request_notifications(from_token=None, limit=1, expected_count=0)
|
||||||
"GET",
|
|
||||||
"/notifications",
|
|
||||||
access_token=other_access_token,
|
|
||||||
)
|
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
|
||||||
self.assertEqual(len(channel.json_body["notifications"]), 0, channel.json_body)
|
|
||||||
|
|
||||||
# Send an invite
|
# Send an invite
|
||||||
self.helper.invite(room=room, src=user_id, targ=other_user_id, tok=access_token)
|
self.helper.invite(
|
||||||
|
room=self.room_id,
|
||||||
|
src=self.user_id,
|
||||||
|
targ=self.other_user_id,
|
||||||
|
tok=self.access_token,
|
||||||
|
)
|
||||||
|
|
||||||
# We should have a notification now
|
# We should have a notification now
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
"/notifications",
|
"/notifications",
|
||||||
access_token=other_access_token,
|
access_token=self.other_access_token,
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body)
|
self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body)
|
||||||
|
@ -95,3 +94,139 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
"invite",
|
"invite",
|
||||||
channel.json_body,
|
channel.json_body,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_pagination_of_notifications(self) -> None:
|
||||||
|
"""
|
||||||
|
Check that pagination of notifications works.
|
||||||
|
"""
|
||||||
|
# Check we start with no pushes
|
||||||
|
self._request_notifications(from_token=None, limit=1, expected_count=0)
|
||||||
|
|
||||||
|
# Send an invite and have the other user join the room.
|
||||||
|
self.helper.invite(
|
||||||
|
room=self.room_id,
|
||||||
|
src=self.user_id,
|
||||||
|
targ=self.other_user_id,
|
||||||
|
tok=self.access_token,
|
||||||
|
)
|
||||||
|
self.helper.join(self.room_id, self.other_user_id, tok=self.other_access_token)
|
||||||
|
|
||||||
|
# Send 5 messages in the room and note down their event IDs.
|
||||||
|
sent_event_ids = []
|
||||||
|
for _ in range(5):
|
||||||
|
resp = self.helper.send_event(
|
||||||
|
self.room_id,
|
||||||
|
"m.room.message",
|
||||||
|
{"body": "honk", "msgtype": "m.text"},
|
||||||
|
tok=self.access_token,
|
||||||
|
)
|
||||||
|
sent_event_ids.append(resp["event_id"])
|
||||||
|
|
||||||
|
# We expect to get notifications for messages in reverse order.
|
||||||
|
# So reverse this list of event IDs to make it easier to compare
|
||||||
|
# against later.
|
||||||
|
sent_event_ids.reverse()
|
||||||
|
|
||||||
|
# We should have a few notifications now. Let's try and fetch the first 2.
|
||||||
|
notification_event_ids, _ = self._request_notifications(
|
||||||
|
from_token=None, limit=2, expected_count=2
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check we got the expected event IDs back.
|
||||||
|
self.assertEqual(notification_event_ids, sent_event_ids[:2])
|
||||||
|
|
||||||
|
# Try requesting again without a 'from' query parameter. We should get the
|
||||||
|
# same two notifications back.
|
||||||
|
notification_event_ids, next_token = self._request_notifications(
|
||||||
|
from_token=None, limit=2, expected_count=2
|
||||||
|
)
|
||||||
|
self.assertEqual(notification_event_ids, sent_event_ids[:2])
|
||||||
|
|
||||||
|
# Ask for the next 5 notifications, though there should only be
|
||||||
|
# 4 remaining; the next 3 messages and the invite.
|
||||||
|
#
|
||||||
|
# We need to use the "next_token" from the response as the "from"
|
||||||
|
# query parameter in the next request in order to paginate.
|
||||||
|
notification_event_ids, next_token = self._request_notifications(
|
||||||
|
from_token=next_token, limit=5, expected_count=4
|
||||||
|
)
|
||||||
|
# Ensure we chop off the invite on the end.
|
||||||
|
notification_event_ids = notification_event_ids[:-1]
|
||||||
|
self.assertEqual(notification_event_ids, sent_event_ids[2:])
|
||||||
|
|
||||||
|
def _request_notifications(
|
||||||
|
self, from_token: Optional[str], limit: int, expected_count: int
|
||||||
|
) -> Tuple[List[str], str]:
|
||||||
|
"""
|
||||||
|
Make a request to /notifications to get the latest events to be notified about.
|
||||||
|
|
||||||
|
Only the event IDs are returned. The request is made by the "other user".
|
||||||
|
|
||||||
|
Args:
|
||||||
|
from_token: An optional starting parameter.
|
||||||
|
limit: The maximum number of results to return.
|
||||||
|
expected_count: The number of events to expect in the response.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of event IDs that the client should be notified about.
|
||||||
|
Events are returned newest-first.
|
||||||
|
"""
|
||||||
|
# Construct the request path.
|
||||||
|
path = f"/notifications?limit={limit}"
|
||||||
|
if from_token is not None:
|
||||||
|
path += f"&from={from_token}"
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
path,
|
||||||
|
access_token=self.other_access_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
self.assertEqual(
|
||||||
|
len(channel.json_body["notifications"]), expected_count, channel.json_body
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the necessary data from the response.
|
||||||
|
next_token = channel.json_body["next_token"]
|
||||||
|
event_ids = [
|
||||||
|
event["event"]["event_id"] for event in channel.json_body["notifications"]
|
||||||
|
]
|
||||||
|
|
||||||
|
return event_ids, next_token
|
||||||
|
|
||||||
|
def test_parameters(self) -> None:
|
||||||
|
"""
|
||||||
|
Test that appropriate errors are returned when query parameters are malformed.
|
||||||
|
"""
|
||||||
|
# Test that no parameters are required.
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/notifications",
|
||||||
|
access_token=self.other_access_token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
|
||||||
|
# Test that limit cannot be negative
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/notifications?limit=-1",
|
||||||
|
access_token=self.other_access_token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 400)
|
||||||
|
|
||||||
|
# Test that the 'limit' parameter must be an integer.
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/notifications?limit=foobar",
|
||||||
|
access_token=self.other_access_token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 400)
|
||||||
|
|
||||||
|
# Test that the 'from' parameter must be an integer.
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/notifications?from=osborne",
|
||||||
|
access_token=self.other_access_token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 400)
|
||||||
|
|
|
@ -447,7 +447,14 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Actually call the function that calculates the auth chain stuff.
|
# Actually call the function that calculates the auth chain stuff.
|
||||||
persist_events_store._persist_event_auth_chain_txn(txn, events)
|
new_event_links = (
|
||||||
|
persist_events_store.calculate_chain_cover_index_for_events_txn(
|
||||||
|
txn, events[0].room_id, [e for e in events if e.is_state()]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
persist_events_store._persist_event_auth_chain_txn(
|
||||||
|
txn, events, new_event_links
|
||||||
|
)
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
persist_events_store.db_pool.runInteraction(
|
persist_events_store.db_pool.runInteraction(
|
||||||
|
|
|
@ -365,12 +365,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.persist_events._persist_event_auth_chain_txn(
|
events = [
|
||||||
txn,
|
|
||||||
[
|
|
||||||
cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
|
cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
|
||||||
for event_id in AUTH_GRAPH
|
for event_id in AUTH_GRAPH
|
||||||
],
|
]
|
||||||
|
new_event_links = (
|
||||||
|
self.persist_events.calculate_chain_cover_index_for_events_txn(
|
||||||
|
txn, room_id, [e for e in events if e.is_state()]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.persist_events._persist_event_auth_chain_txn(
|
||||||
|
txn,
|
||||||
|
events,
|
||||||
|
new_event_links,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -628,13 +635,20 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Insert all events apart from 'B'
|
# Insert all events apart from 'B'
|
||||||
self.persist_events._persist_event_auth_chain_txn(
|
events = [
|
||||||
txn,
|
|
||||||
[
|
|
||||||
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
|
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
|
||||||
for event_id in auth_graph
|
for event_id in auth_graph
|
||||||
if event_id != "b"
|
if event_id != "b"
|
||||||
],
|
]
|
||||||
|
new_event_links = (
|
||||||
|
self.persist_events.calculate_chain_cover_index_for_events_txn(
|
||||||
|
txn, room_id, [e for e in events if e.is_state()]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.persist_events._persist_event_auth_chain_txn(
|
||||||
|
txn,
|
||||||
|
events,
|
||||||
|
new_event_links,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now we insert the event 'B' without a chain cover, by temporarily
|
# Now we insert the event 'B' without a chain cover, by temporarily
|
||||||
|
@ -647,9 +661,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
updatevalues={"has_auth_chain_index": False},
|
updatevalues={"has_auth_chain_index": False},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
events = [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))]
|
||||||
|
new_event_links = (
|
||||||
|
self.persist_events.calculate_chain_cover_index_for_events_txn(
|
||||||
|
txn, room_id, [e for e in events if e.is_state()]
|
||||||
|
)
|
||||||
|
)
|
||||||
self.persist_events._persist_event_auth_chain_txn(
|
self.persist_events._persist_event_auth_chain_txn(
|
||||||
txn,
|
txn, events, new_event_links
|
||||||
[cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.store.db_pool.simple_update_txn(
|
self.store.db_pool.simple_update_txn(
|
||||||
|
|
Loading…
Reference in a new issue