Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2024-04-08 10:11:02 +01:00
commit 6e95084685
16 changed files with 724 additions and 230 deletions

View file

@ -1,3 +1,10 @@
# Synapse 1.104.0 (2024-04-02)
### Bugfixes
- Fix regression when using OIDC provider. Introduced in v1.104.0rc1. ([\#17031](https://github.com/element-hq/synapse/issues/17031))
# Synapse 1.104.0rc1 (2024-03-26)
### Features

1
changelog.d/16930.bugfix Normal file
View file

@ -0,0 +1 @@
Fix various long-standing bugs which could cause incorrect state to be returned from `/sync` in certain situations.

1
changelog.d/16932.bugfix Normal file
View file

@ -0,0 +1 @@
Fix various long-standing bugs which could cause incorrect state to be returned from `/sync` in certain situations.

1
changelog.d/16942.bugfix Normal file
View file

@ -0,0 +1 @@
Fix various long-standing bugs which could cause incorrect state to be returned from `/sync` in certain situations.

View file

@ -1 +0,0 @@
OIDC: try to JWT decode userinfo response if JSON parsing failed.

1
changelog.d/17044.misc Normal file
View file

@ -0,0 +1 @@
Refactor auth chain fetching to reduce duplication.

1
changelog.d/17045.misc Normal file
View file

@ -0,0 +1 @@
Improve database performance by adding a missing index to `access_tokens.refresh_token_id`.

1
changelog.d/17049.misc Normal file
View file

@ -0,0 +1 @@
Improve database performance by reducing number of receipts fetched when sending push notifications.

6
debian/changelog vendored
View file

@ -1,3 +1,9 @@
matrix-synapse-py3 (1.104.0) stable; urgency=medium
* New Synapse release 1.104.0.
-- Synapse Packaging team <packages@matrix.org> Tue, 02 Apr 2024 17:15:45 +0100
matrix-synapse-py3 (1.104.0~rc1) stable; urgency=medium
* New Synapse release 1.104.0rc1.

View file

@ -96,7 +96,7 @@ module-name = "synapse.synapse_rust"
[tool.poetry]
name = "matrix-synapse"
version = "1.104.0rc1"
version = "1.104.0"
description = "Homeserver for the Matrix decentralised comms protocol"
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
license = "AGPL-3.0-or-later"

View file

@ -953,7 +953,7 @@ class SyncHandler:
batch: TimelineBatch,
sync_config: SyncConfig,
since_token: Optional[StreamToken],
now_token: StreamToken,
end_token: StreamToken,
full_state: bool,
) -> MutableStateMap[EventBase]:
"""Works out the difference in state between the end of the previous sync and
@ -964,7 +964,9 @@ class SyncHandler:
batch: The timeline batch for the room that will be sent to the user.
sync_config:
since_token: Token of the end of the previous batch. May be `None`.
now_token: Token of the end of the current batch.
end_token: Token of the end of the current batch. Normally this will be
the same as the global "now_token", but if the user has left the room,
the point just after their leave event.
full_state: Whether to force returning the full state.
`lazy_load_members` still applies when `full_state` is `True`.
@ -1044,7 +1046,7 @@ class SyncHandler:
room_id,
sync_config.user,
batch,
now_token,
end_token,
members_to_fetch,
timeline_state,
)
@ -1058,7 +1060,7 @@ class SyncHandler:
room_id,
batch,
since_token,
now_token,
end_token,
members_to_fetch,
timeline_state,
)
@ -1130,7 +1132,7 @@ class SyncHandler:
room_id: str,
syncing_user: UserID,
batch: TimelineBatch,
now_token: StreamToken,
end_token: StreamToken,
members_to_fetch: Optional[Set[str]],
timeline_state: StateMap[str],
) -> StateMap[str]:
@ -1143,7 +1145,9 @@ class SyncHandler:
room_id: The room we are calculating for.
syncing_user: The user that is calling `/sync`.
batch: The timeline batch for the room that will be sent to the user.
now_token: Token of the end of the current batch.
end_token: Token of the end of the current batch. Normally this will be
the same as the global "now_token", but if the user has left the room,
the point just after their leave event.
members_to_fetch: If lazy-loading is enabled, the memberships needed for
events in the timeline.
timeline_state: The contribution to the room state from state events in
@ -1183,15 +1187,16 @@ class SyncHandler:
await_full_state = True
lazy_load_members = False
if batch:
state_at_timeline_end = (
await self._state_storage_controller.get_state_ids_for_event(
batch.events[-1].event_id,
state_at_timeline_end = await self.get_state_at(
room_id,
stream_position=end_token,
state_filter=state_filter,
await_full_state=await_full_state,
)
)
if batch:
# Strictly speaking, this returns the state *after* the first event in the
# timeline, but that is good enough here.
state_at_timeline_start = (
await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id,
@ -1200,13 +1205,6 @@ class SyncHandler:
)
)
else:
state_at_timeline_end = await self.get_state_at(
room_id,
stream_position=now_token,
state_filter=state_filter,
await_full_state=await_full_state,
)
state_at_timeline_start = state_at_timeline_end
state_ids = _calculate_state(
@ -1223,7 +1221,7 @@ class SyncHandler:
room_id: str,
batch: TimelineBatch,
since_token: StreamToken,
now_token: StreamToken,
end_token: StreamToken,
members_to_fetch: Optional[Set[str]],
timeline_state: StateMap[str],
) -> StateMap[str]:
@ -1239,7 +1237,9 @@ class SyncHandler:
room_id: The room we are calculating for.
batch: The timeline batch for the room that will be sent to the user.
since_token: Token of the end of the previous batch.
now_token: Token of the end of the current batch.
end_token: Token of the end of the current batch. Normally this will be
the same as the global "now_token", but if the user has left the room,
the point just after their leave event.
members_to_fetch: If lazy-loading is enabled, the memberships needed for
events in the timeline. Otherwise, `None`.
timeline_state: The contribution to the room state from state events in
@ -1259,7 +1259,6 @@ class SyncHandler:
await_full_state = True
lazy_load_members = False
if batch.limited:
if batch:
state_at_timeline_start = (
await self._state_storage_controller.get_state_ids_for_event(
@ -1273,11 +1272,12 @@ class SyncHandler:
# the recent events.
state_at_timeline_start = await self.get_state_at(
room_id,
stream_position=now_token,
stream_position=end_token,
state_filter=state_filter,
await_full_state=await_full_state,
)
if batch.limited:
# for now, we disable LL for gappy syncs - see
# https://github.com/vector-im/riot-web/issues/7211#issuecomment-419976346
# N.B. this slows down incr syncs as we are now processing way
@ -1299,20 +1299,9 @@ class SyncHandler:
await_full_state=await_full_state,
)
if batch:
state_at_timeline_end = (
await self._state_storage_controller.get_state_ids_for_event(
batch.events[-1].event_id,
state_filter=state_filter,
await_full_state=await_full_state,
)
)
else:
# We can get here if the user has ignored the senders of all
# the recent events.
state_at_timeline_end = await self.get_state_at(
room_id,
stream_position=now_token,
stream_position=end_token,
state_filter=state_filter,
await_full_state=await_full_state,
)
@ -1324,26 +1313,7 @@ class SyncHandler:
previous_timeline_end=state_at_previous_sync,
lazy_load_members=lazy_load_members,
)
else:
state_ids = {}
if lazy_load_members:
if members_to_fetch and batch.events:
# We're returning an incremental sync, with no
# "gap" since the previous sync, so normally there would be
# no state to return.
# But we're lazy-loading, so the client might need some more
# member events to understand the events in this timeline.
# So we fish out all the member events corresponding to the
# timeline here. The caller will then dedupe any redundant ones.
state_ids = await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id,
# we only want members!
state_filter=StateFilter.from_types(
(EventTypes.Member, member) for member in members_to_fetch
),
await_full_state=False,
)
return state_ids
async def _find_missing_partial_state_memberships(
@ -2344,6 +2314,7 @@ class SyncHandler:
full_state=False,
since_token=since_token,
upto_token=leave_token,
end_token=leave_token,
out_of_band=leave_event.internal_metadata.is_out_of_band_membership(),
)
)
@ -2381,6 +2352,7 @@ class SyncHandler:
full_state=False,
since_token=None if newly_joined else since_token,
upto_token=prev_batch_token,
end_token=now_token,
)
else:
entry = RoomSyncResultBuilder(
@ -2391,6 +2363,7 @@ class SyncHandler:
full_state=False,
since_token=since_token,
upto_token=since_token,
end_token=now_token,
)
room_entries.append(entry)
@ -2449,6 +2422,7 @@ class SyncHandler:
full_state=True,
since_token=since_token,
upto_token=now_token,
end_token=now_token,
)
)
elif event.membership == Membership.INVITE:
@ -2478,6 +2452,7 @@ class SyncHandler:
full_state=True,
since_token=since_token,
upto_token=leave_token,
end_token=leave_token,
)
)
@ -2548,6 +2523,7 @@ class SyncHandler:
{
"since_token": since_token,
"upto_token": upto_token,
"end_token": room_builder.end_token,
}
)
@ -2621,7 +2597,7 @@ class SyncHandler:
batch,
sync_config,
since_token,
now_token,
room_builder.end_token,
full_state=full_state,
)
else:
@ -2781,6 +2757,61 @@ def _calculate_state(
e for t, e in timeline_start.items() if t[0] == EventTypes.Member
)
# Naively, we would just return the difference between the state at the start
# of the timeline (`timeline_start_ids`) and that at the end of the previous sync
# (`previous_timeline_end_ids`). However, that fails in the presence of forks in
# the DAG.
#
# For example, consider a DAG such as the following:
#
# E1
# ↗ ↖
# | S2
# | ↑
# --|------|----
# | |
# E3 |
# ↖ /
# E4
#
# ... and a filter that means we only return 2 events, represented by the dashed
# horizontal line. Assuming S2 was *not* included in the previous sync, we need to
# include it in the `state` section.
#
# Note that the state at the start of the timeline (E3) does not include S2. So,
# to make sure it gets included in the calculation here, we actually look at
# the state at the *end* of the timeline, and subtract any events that are present
# in the timeline.
#
# ----------
#
# Aside 1: You may then wonder if we need to include `timeline_start` in the
# calculation. Consider a linear DAG:
#
# E1
# ↑
# S2
# ↑
# ----|------
# |
# E3
# ↑
# S4
# ↑
# E5
#
# ... where S2 and S4 change the same piece of state; and where we have a filter
# that returns 3 events (E3, S4, E5). We still need to tell the client about S2,
# because it might affect the display of E3. However, the state at the end of the
# timeline only tells us about S4; if we don't inspect `timeline_start` we won't
# find out about S2.
#
# (There are yet more complicated cases in which a state event is excluded from the
# timeline, but whose effect actually lands in the DAG in the *middle* of the
# timeline. We have no way to represent that in the /sync response, and we don't
# even try; it is ether omitted or plonked into `state` as if it were at the start
# of the timeline, depending on what else is in the timeline.)
state_ids = (
(timeline_end_ids | timeline_start_ids)
- previous_timeline_end_ids
@ -2883,13 +2914,30 @@ class RoomSyncResultBuilder:
Attributes:
room_id
rtype: One of `"joined"` or `"archived"`
events: List of events to include in the room (more events may be added
when generating result).
newly_joined: If the user has newly joined the room
full_state: Whether the full state should be sent in result
since_token: Earliest point to return events from, or None
upto_token: Latest point to return events from.
upto_token: Latest point to return events from. If `events` is populated,
this is set to the token at the start of `events`
end_token: The last point in the timeline that the client should see events
from. Normally this will be the same as the global `now_token`, but in
the case of rooms where the user has left the room, this will be the point
just after their leave event.
This is used in the calculation of the state which is returned in `state`:
any state changes *up to* `end_token` (and not beyond!) which are not
reflected in the timeline need to be returned in `state`.
out_of_band: whether the events in the room are "out of band" events
and the server isn't in the room.
"""
@ -2901,5 +2949,5 @@ class RoomSyncResultBuilder:
full_state: bool
since_token: Optional[StreamToken]
upto_token: StreamToken
end_token: StreamToken
out_of_band: bool = False

View file

@ -27,6 +27,7 @@ from typing import (
Collection,
Dict,
FrozenSet,
Generator,
Iterable,
List,
Optional,
@ -279,64 +280,16 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
#
# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
# structuring the recursive part of query to pull out the links without
# also returning large quantities of redundant data (which can make it a
# lot slower).
sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
DISTINCT origin_chain_id
FROM event_auth_chain_links WHERE %s
UNION
SELECT
target_chain_id
FROM event_auth_chain_links
INNER JOIN links ON (chain_id = origin_chain_id)
)
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM links
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""
# A map from chain ID to max sequence number *reachable* from any event ID.
chains: Dict[int, int] = {}
# Add all linked chains reachable from initial set of chains.
chains_to_fetch = set(event_chains.keys())
while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
chains_to_fetch.difference_update(batch2)
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
links: Dict[int, List[Tuple[int, int, int]]] = {}
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
links.setdefault(origin_chain_id, []).append(
(origin_sequence_number, target_chain_id, target_sequence_number)
)
for links in self._get_chain_links(txn, set(event_chains.keys())):
for chain_id in links:
if chain_id not in event_chains:
continue
_materialize(chain_id, event_chains[chain_id], links, chains)
chains_to_fetch.difference_update(chains)
# Add the initial set of chains, excluding the sequence corresponding to
# initial event.
for chain_id, seq_no in event_chains.items():
@ -380,6 +333,68 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return results
@classmethod
def _get_chain_links(
cls, txn: LoggingTransaction, chains_to_fetch: Set[int]
) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
"""Fetch all auth chain links from the given set of chains, and all
links from those chains, recursively.
Note: This may return links that are not reachable from the given
chains.
Returns a generator that produces dicts from origin chain ID to 3-tuple
of origin sequence number, target chain ID and target sequence number.
"""
# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
# structuring the recursive part of query to pull out the links without
# also returning large quantities of redundant data (which can make it a
# lot slower).
sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
DISTINCT origin_chain_id
FROM event_auth_chain_links WHERE %s
UNION
SELECT
target_chain_id
FROM event_auth_chain_links
INNER JOIN links ON (chain_id = origin_chain_id)
)
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM links
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""
while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
chains_to_fetch.difference_update(batch2)
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
links: Dict[int, List[Tuple[int, int, int]]] = {}
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
links.setdefault(origin_chain_id, []).append(
(origin_sequence_number, target_chain_id, target_sequence_number)
)
chains_to_fetch.difference_update(links)
yield links
def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> Set[str]:
@ -564,53 +579,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
#
# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
# structuring the recursive part of query to pull out the links without
# also returning large quantities of redundant data (which can make it a
# lot slower).
sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
DISTINCT origin_chain_id
FROM event_auth_chain_links WHERE %s
UNION
SELECT
target_chain_id
FROM event_auth_chain_links
INNER JOIN links ON (chain_id = origin_chain_id)
)
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM links
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""
# (We need to take a copy of `seen_chains` as we want to mutate it in
# the loop)
chains_to_fetch = set(seen_chains)
while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
links: Dict[int, List[Tuple[int, int, int]]] = {}
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
links.setdefault(origin_chain_id, []).append(
(origin_sequence_number, target_chain_id, target_sequence_number)
)
# (We need to take a copy of `seen_chains` as the function mutates it)
for links in self._get_chain_links(txn, set(seen_chains)):
for chains in set_to_chain:
for chain_id in links:
if chain_id not in chains:
@ -618,7 +589,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
_materialize(chain_id, chains[chain_id], links, chains)
chains_to_fetch.difference_update(chains)
seen_chains.update(chains)
# Now for each chain we figure out the maximum sequence number reachable

View file

@ -106,7 +106,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.types import JsonDict
from synapse.types import JsonDict, StrCollection
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@ -859,37 +859,86 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
def _get_receipts_by_room_txn(
self, txn: LoggingTransaction, user_id: str
def _get_receipts_for_room_and_threads_txn(
self,
txn: LoggingTransaction,
user_id: str,
room_ids: StrCollection,
thread_ids: StrCollection,
) -> Dict[str, _RoomReceipt]:
"""
Generate a map of room ID to the latest stream ordering that has been
read by the given user.
Get (private) read receipts for a user in each of the given room IDs
and thread IDs.
Args:
txn:
user_id: The user to fetch receipts for.
Note: The corresponding room ID for each thread must appear in
`room_ids` arg.
Returns:
A map including all rooms the user is in with a receipt. It maps
room IDs to _RoomReceipt instances
"""
receipt_types_clause, args = make_in_list_sql_clause(
receipt_types_clause, receipts_args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)
thread_ids_clause, thread_ids_args = make_in_list_sql_clause(
self.database_engine,
"thread_id",
thread_ids,
)
room_ids_clause, room_ids_args = make_in_list_sql_clause(
self.database_engine,
"room_id",
room_ids,
)
# We use the union of two (almost identical) queries here, the first to
# fetch the specific thread receipts and the second to fetch the
# unthreaded receipts.
#
# This SQL is optimized to use the indices we have on
# `receipts_linearized`.
#
# We compare room ID and thread IDs independently due to the above,
# which means that this query might return more rows than we need if the
# same thread ID appears across different rooms (e.g. 'main' thread ID).
# This doesn't cause any logic issues, and isn't a performance concern
# given this function generally gets called with only one room and
# thread ID.
sql = f"""
SELECT room_id, thread_id, MAX(stream_ordering)
FROM receipts_linearized
INNER JOIN events USING (room_id, event_id)
WHERE {receipt_types_clause}
AND {thread_ids_clause}
AND {room_ids_clause}
AND user_id = ?
GROUP BY room_id, thread_id
UNION ALL
SELECT room_id, thread_id, MAX(stream_ordering)
FROM receipts_linearized
INNER JOIN events USING (room_id, event_id)
WHERE {receipt_types_clause}
AND {room_ids_clause}
AND thread_id IS NULL
AND user_id = ?
GROUP BY room_id, thread_id
"""
args.extend((user_id,))
args = list(receipts_args)
args.extend(thread_ids_args)
args.extend(room_ids_args)
args.append(user_id)
args.extend(receipts_args)
args.extend(room_ids_args)
args.append(user_id)
txn.execute(sql, args)
result: Dict[str, _RoomReceipt] = {}
@ -925,12 +974,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will have between 0~limit entries.
"""
receipts_by_room = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
)
def get_push_actions_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, str, int, str, bool]]:
@ -952,6 +995,27 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
"get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
)
room_ids = set()
thread_ids = []
for (
_,
room_id,
thread_id,
_,
_,
_,
) in push_actions:
room_ids.add(room_id)
thread_ids.append(thread_id)
receipts_by_room = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_receipts",
self._get_receipts_for_room_and_threads_txn,
user_id=user_id,
room_ids=room_ids,
thread_ids=thread_ids,
)
notifs = [
HttpPushAction(
event_id=event_id,
@ -998,12 +1062,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will have between 0~limit entries.
"""
receipts_by_room = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
)
def get_push_actions_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, str, int, str, bool, int]]:
@ -1026,6 +1084,28 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
"get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
)
room_ids = set()
thread_ids = []
for (
_,
room_id,
thread_id,
_,
_,
_,
_,
) in push_actions:
room_ids.add(room_id)
thread_ids.append(thread_id)
receipts_by_room = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_receipts",
self._get_receipts_for_room_and_threads_txn,
user_id=user_id,
room_ids=room_ids,
thread_ids=thread_ids,
)
# Make a list of dicts from the two sets of results.
notifs = [
EmailPushAction(

View file

@ -2266,6 +2266,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
update_name="access_tokens_refresh_token_id_idx",
index_name="access_tokens_refresh_token_id_idx",
table="access_tokens",
columns=("refresh_token_id",),
)
self._ignore_unknown_session_error = (
hs.config.server.request_token_inhibit_3pid_errors
)

View file

@ -17,14 +17,16 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import Collection, List, Optional
from typing import Collection, ContextManager, List, Optional
from unittest.mock import AsyncMock, Mock, patch
from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import Filtering
from synapse.api.filtering import FilterCollection, Filtering
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
@ -33,7 +35,7 @@ from synapse.handlers.sync import SyncConfig, SyncResult
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
import tests.unittest
@ -258,13 +260,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Eve tries to join the room. We monkey patch the internal logic which selects
# the prev_events used when creating the join event, such that the ban does not
# precede the join.
mocked_get_prev_events = patch.object(
self.hs.get_datastores().main,
"get_prev_events_for_room",
new_callable=AsyncMock,
return_value=[last_room_creation_event_id],
)
with mocked_get_prev_events:
with self._patch_get_latest_events([last_room_creation_event_id]):
self.helper.join(room_id, eve, tok=eve_token)
# Eve makes a second, incremental sync.
@ -288,6 +284,365 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
self.assertEqual(eve_initial_sync_after_join.joined, [])
def test_state_includes_changes_on_forks(self) -> None:
"""State changes that happen on a fork of the DAG must be included in `state`
Given the following DAG:
E1
| S2
|
--|------|----
| |
E3 |
/
E4
... and a filter that means we only return 2 events, represented by the dashed
horizontal line: `S2` must be included in the `state` section.
"""
alice = self.register_user("alice", "password")
alice_tok = self.login(alice, "password")
alice_requester = create_requester(alice)
room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok)
# Do an initial sync as Alice to get a known starting point.
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester, generate_sync_config(alice)
)
)
last_room_creation_event_id = (
initial_sync_result.joined[0].timeline.events[-1].event_id
)
# Send a state event, and a regular event, both using the same prev ID
with self._patch_get_latest_events([last_room_creation_event_id]):
s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[
"event_id"
]
e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"]
# Send a final event, joining the two branches of the dag
e4_event = self.helper.send(room_id, "e4", tok=alice_tok)["event_id"]
# do an incremental sync, with a filter that will ensure we only get two of
# the three new events.
incremental_sync = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
generate_sync_config(
alice,
filter_collection=FilterCollection(
self.hs, {"room": {"timeline": {"limit": 2}}}
),
),
since_token=initial_sync_result.next_batch,
)
)
# The state event should appear in the 'state' section of the response.
room_sync = incremental_sync.joined[0]
self.assertEqual(room_sync.room_id, room_id)
self.assertTrue(room_sync.timeline.limited)
self.assertEqual(
[e.event_id for e in room_sync.timeline.events],
[e3_event, e4_event],
)
self.assertEqual(
[e.event_id for e in room_sync.state.values()],
[s2_event],
)
def test_state_includes_changes_on_forks_when_events_excluded(self) -> None:
"""A variation on the previous test, but where one event is filtered
The DAG is the same as the previous test, but E4 is excluded by the filter.
E1
| S2
|
--|------|----
| |
E3 |
/
(E4)
"""
alice = self.register_user("alice", "password")
alice_tok = self.login(alice, "password")
alice_requester = create_requester(alice)
room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok)
# Do an initial sync as Alice to get a known starting point.
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester, generate_sync_config(alice)
)
)
last_room_creation_event_id = (
initial_sync_result.joined[0].timeline.events[-1].event_id
)
# Send a state event, and a regular event, both using the same prev ID
with self._patch_get_latest_events([last_room_creation_event_id]):
s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[
"event_id"
]
e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"]
# Send a final event, joining the two branches of the dag
self.helper.send(room_id, "e4", type="not_a_normal_message", tok=alice_tok)[
"event_id"
]
# do an incremental sync, with a filter that will only return E3, excluding S2
# and E4.
incremental_sync = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
generate_sync_config(
alice,
filter_collection=FilterCollection(
self.hs,
{
"room": {
"timeline": {
"limit": 1,
"not_types": ["not_a_normal_message"],
}
}
},
),
),
since_token=initial_sync_result.next_batch,
)
)
# The state event should appear in the 'state' section of the response.
room_sync = incremental_sync.joined[0]
self.assertEqual(room_sync.room_id, room_id)
self.assertTrue(room_sync.timeline.limited)
self.assertEqual(
[e.event_id for e in room_sync.timeline.events],
[e3_event],
)
self.assertEqual(
[e.event_id for e in room_sync.state.values()],
[s2_event],
)
def test_state_includes_changes_on_ungappy_syncs(self) -> None:
"""Test `state` where the sync is not gappy.
We start with a DAG like this:
E1
| S2
|
--|---
|
E3
... and initialsync with `limit=1`, represented by the horizontal dashed line.
At this point, we do not expect S2 to appear in the response at all (since
it is excluded from the timeline by the `limit`, and the state is based on the
state after the most recent event before the sync token (E3), which doesn't
include S2.
Now more events arrive, and we do an incremental sync:
E1
| S2
|
E3 |
|
--|------|----
| |
E4 |
/
E5
This is the last chance for us to tell the client about S2, so it *must* be
included in the response.
"""
alice = self.register_user("alice", "password")
alice_tok = self.login(alice, "password")
alice_requester = create_requester(alice)
room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok)
# Do an initial sync to get a known starting point.
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester, generate_sync_config(alice)
)
)
last_room_creation_event_id = (
initial_sync_result.joined[0].timeline.events[-1].event_id
)
# Send a state event, and a regular event, both using the same prev ID
with self._patch_get_latest_events([last_room_creation_event_id]):
s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[
"event_id"
]
e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"]
# Another initial sync, with limit=1
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
generate_sync_config(
alice,
filter_collection=FilterCollection(
self.hs, {"room": {"timeline": {"limit": 1}}}
),
),
)
)
room_sync = initial_sync_result.joined[0]
self.assertEqual(room_sync.room_id, room_id)
self.assertEqual(
[e.event_id for e in room_sync.timeline.events],
[e3_event],
)
self.assertNotIn(s2_event, [e.event_id for e in room_sync.state.values()])
# More events, E4 and E5
with self._patch_get_latest_events([e3_event]):
e4_event = self.helper.send(room_id, "e4", tok=alice_tok)["event_id"]
e5_event = self.helper.send(room_id, "e5", tok=alice_tok)["event_id"]
# Now incremental sync
incremental_sync = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
generate_sync_config(alice),
since_token=initial_sync_result.next_batch,
)
)
# The state event should appear in the 'state' section of the response.
room_sync = incremental_sync.joined[0]
self.assertEqual(room_sync.room_id, room_id)
self.assertFalse(room_sync.timeline.limited)
self.assertEqual(
[e.event_id for e in room_sync.timeline.events],
[e4_event, e5_event],
)
self.assertEqual(
[e.event_id for e in room_sync.state.values()],
[s2_event],
)
@parameterized.expand(
[
(False, False),
(True, False),
(False, True),
(True, True),
]
)
def test_archived_rooms_do_not_include_state_after_leave(
self, initial_sync: bool, empty_timeline: bool
) -> None:
"""If the user leaves the room, state changes that happen after they leave are not returned.
We try with both a zero and a normal timeline limit,
and we try both an initial sync and an incremental sync for both.
"""
if empty_timeline and not initial_sync:
# FIXME synapse doesn't return the room at all in this situation!
self.skipTest("Synapse does not correctly handle this case")
# Alice creates the room, and bob joins.
alice = self.register_user("alice", "password")
alice_tok = self.login(alice, "password")
bob = self.register_user("bob", "password")
bob_tok = self.login(bob, "password")
bob_requester = create_requester(bob)
room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok)
self.helper.join(room_id, bob, tok=bob_tok)
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
bob_requester, generate_sync_config(bob)
)
)
# Alice sends a message and a state
before_message_event = self.helper.send(room_id, "before", tok=alice_tok)[
"event_id"
]
before_state_event = self.helper.send_state(
room_id, "test_state", {"body": "before"}, tok=alice_tok
)["event_id"]
# Bob leaves
leave_event = self.helper.leave(room_id, bob, tok=bob_tok)["event_id"]
# Alice sends some more stuff
self.helper.send(room_id, "after", tok=alice_tok)["event_id"]
self.helper.send_state(room_id, "test_state", {"body": "after"}, tok=alice_tok)[
"event_id"
]
# And now, Bob resyncs.
filter_dict: JsonDict = {"room": {"include_leave": True}}
if empty_timeline:
filter_dict["room"]["timeline"] = {"limit": 0}
sync_room_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
bob_requester,
generate_sync_config(
bob, filter_collection=FilterCollection(self.hs, filter_dict)
),
since_token=None if initial_sync else initial_sync_result.next_batch,
)
).archived[0]
if empty_timeline:
# The timeline should be empty
self.assertEqual(sync_room_result.timeline.events, [])
# And the state should include the leave event...
self.assertEqual(
sync_room_result.state[("m.room.member", bob)].event_id, leave_event
)
# ... and the state change before he left.
self.assertEqual(
sync_room_result.state[("test_state", "")].event_id, before_state_event
)
else:
# The last three events in the timeline should be those leading up to the
# leave
self.assertEqual(
[e.event_id for e in sync_room_result.timeline.events[-3:]],
[before_message_event, before_state_event, leave_event],
)
# ... And the state should be empty
self.assertEqual(sync_room_result.state, {})
def _patch_get_latest_events(self, latest_events: List[str]) -> ContextManager:
"""Monkey-patch `get_prev_events_for_room`
Returns a context manager which will replace the implementation of
`get_prev_events_for_room` with one which returns `latest_events`.
"""
return patch.object(
self.hs.get_datastores().main,
"get_prev_events_for_room",
new_callable=AsyncMock,
return_value=latest_events,
)
def test_call_invite_in_public_room_not_returned(self) -> None:
user = self.register_user("alice", "password")
tok = self.login(user, "password")
@ -401,14 +756,26 @@ _request_key = 0
def generate_sync_config(
user_id: str, device_id: Optional[str] = "device_id"
user_id: str,
device_id: Optional[str] = "device_id",
filter_collection: Optional[FilterCollection] = None,
) -> SyncConfig:
"""Generate a sync config (with a unique request key)."""
"""Generate a sync config (with a unique request key).
Args:
user_id: user who is syncing.
device_id: device that is syncing. Defaults to "device_id".
filter_collection: filter to apply. Defaults to the default filter (ie,
return everything, with a default limit)
"""
if filter_collection is None:
filter_collection = Filtering(Mock()).DEFAULT_FILTER_COLLECTION
global _request_key
_request_key += 1
return SyncConfig(
user=UserID.from_string(user_id),
filter_collection=Filtering(Mock()).DEFAULT_FILTER_COLLECTION,
filter_collection=filter_collection,
is_guest=False,
request_key=("request_key", _request_key),
device_id=device_id,

View file

@ -170,8 +170,8 @@ class RestHelper:
targ: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
) -> None:
self.change_membership(
) -> JsonDict:
return self.change_membership(
room=room,
src=src,
targ=targ,
@ -189,8 +189,8 @@ class RestHelper:
appservice_user_id: Optional[str] = None,
expect_errcode: Optional[Codes] = None,
expect_additional_fields: Optional[dict] = None,
) -> None:
self.change_membership(
) -> JsonDict:
return self.change_membership(
room=room,
src=user,
targ=user,
@ -242,8 +242,8 @@ class RestHelper:
user: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
) -> None:
self.change_membership(
) -> JsonDict:
return self.change_membership(
room=room,
src=user,
targ=user,
@ -282,7 +282,7 @@ class RestHelper:
expect_code: int = HTTPStatus.OK,
expect_errcode: Optional[str] = None,
expect_additional_fields: Optional[dict] = None,
) -> None:
) -> JsonDict:
"""
Send a membership state event into a room.
@ -298,6 +298,9 @@ class RestHelper:
using an application service access token in `tok`.
expect_code: The expected HTTP response code
expect_errcode: The expected Matrix error code
Returns:
The JSON response
"""
temp_id = self.auth_user_id
self.auth_user_id = src
@ -356,6 +359,7 @@ class RestHelper:
)
self.auth_user_id = temp_id
return channel.json_body
def send(
self,