Consolidate logic for parsing relations. (#12693)

Parse the `m.relates_to` event content field (which describes relations)
in a single place, this is used during:

* Event persistence.
* Validation of the Client-Server API.
* Fetching bundled aggregations.
* Processing of push rules.

Each of these separately implement the logic and each made slightly
different assumptions about what was valid. Some had minor / potential
bugs.
This commit is contained in:
Patrick Cloke 2022-05-16 08:42:45 -04:00 committed by GitHub
parent cde8af9a49
commit 86a515ccbf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 98 additions and 61 deletions

1
changelog.d/12693.misc Normal file
View file

@ -0,0 +1 @@
Consolidate parsing of relation information from events.

View file

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import abc import abc
import collections.abc
import os import os
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -32,9 +33,11 @@ from typing import (
overload, overload,
) )
import attr
from typing_extensions import Literal from typing_extensions import Literal
from unpaddedbase64 import encode_base64 from unpaddedbase64 import encode_base64
from synapse.api.constants import RelationTypes
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
from synapse.types import JsonDict, RoomStreamToken from synapse.types import JsonDict, RoomStreamToken
from synapse.util.caches import intern_dict from synapse.util.caches import intern_dict
@ -615,3 +618,45 @@ def make_event_from_dict(
return event_type( return event_type(
event_dict, room_version, internal_metadata_dict or {}, rejected_reason event_dict, room_version, internal_metadata_dict or {}, rejected_reason
) )
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventRelation:
# The target event of the relation.
parent_id: str
# The relation type.
rel_type: str
# The aggregation key. Will be None if the rel_type is not m.annotation or is
# not a string.
aggregation_key: Optional[str]
def relation_from_event(event: EventBase) -> Optional[_EventRelation]:
"""
Attempt to parse relation information an event.
Returns:
The event relation information, if it is valid. None, otherwise.
"""
relation = event.content.get("m.relates_to")
if not relation or not isinstance(relation, collections.abc.Mapping):
# No relation information.
return None
# Relations must have a type and parent event ID.
rel_type = relation.get("rel_type")
if not isinstance(rel_type, str):
return None
parent_id = relation.get("event_id")
if not isinstance(parent_id, str):
return None
# Annotations have a key field.
aggregation_key = None
if rel_type == RelationTypes.ANNOTATION:
aggregation_key = relation.get("key")
if not isinstance(aggregation_key, str):
aggregation_key = None
return _EventRelation(parent_id, rel_type, aggregation_key)

View file

@ -44,7 +44,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.api.urls import ConsentURIBuilder from synapse.api.urls import ConsentURIBuilder
from synapse.event_auth import validate_event_for_room_version from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase from synapse.events import EventBase, relation_from_event
from synapse.events.builder import EventBuilder from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
@ -1060,20 +1060,11 @@ class EventCreationHandler:
SynapseError if the event is invalid. SynapseError if the event is invalid.
""" """
relation = event.content.get("m.relates_to") relation = relation_from_event(event)
if not relation: if not relation:
return return
relation_type = relation.get("rel_type") parent_event = await self.store.get_event(relation.parent_id, allow_none=True)
if not relation_type:
return
# Ensure the parent is real.
relates_to = relation.get("event_id")
if not relates_to:
return
parent_event = await self.store.get_event(relates_to, allow_none=True)
if parent_event: if parent_event:
# And in the same room. # And in the same room.
if parent_event.room_id != event.room_id: if parent_event.room_id != event.room_id:
@ -1082,28 +1073,31 @@ class EventCreationHandler:
else: else:
# There must be some reason that the client knows the event exists, # There must be some reason that the client knows the event exists,
# see if there are existing relations. If so, assume everything is fine. # see if there are existing relations. If so, assume everything is fine.
if not await self.store.event_is_target_of_relation(relates_to): if not await self.store.event_is_target_of_relation(relation.parent_id):
# Otherwise, the client can't know about the parent event! # Otherwise, the client can't know about the parent event!
raise SynapseError(400, "Can't send relation to unknown event") raise SynapseError(400, "Can't send relation to unknown event")
# If this event is an annotation then we check that that the sender # If this event is an annotation then we check that that the sender
# can't annotate the same way twice (e.g. stops users from liking an # can't annotate the same way twice (e.g. stops users from liking an
# event multiple times). # event multiple times).
if relation_type == RelationTypes.ANNOTATION: if relation.rel_type == RelationTypes.ANNOTATION:
aggregation_key = relation["key"] aggregation_key = relation.aggregation_key
if aggregation_key is None:
raise SynapseError(400, "Missing aggregation key")
if len(aggregation_key) > 500: if len(aggregation_key) > 500:
raise SynapseError(400, "Aggregation key is too long") raise SynapseError(400, "Aggregation key is too long")
already_exists = await self.store.has_user_annotated_event( already_exists = await self.store.has_user_annotated_event(
relates_to, event.type, aggregation_key, event.sender relation.parent_id, event.type, aggregation_key, event.sender
) )
if already_exists: if already_exists:
raise SynapseError(400, "Can't send same reaction twice") raise SynapseError(400, "Can't send same reaction twice")
# Don't attempt to start a thread if the parent event is a relation. # Don't attempt to start a thread if the parent event is a relation.
elif relation_type == RelationTypes.THREAD: elif relation.rel_type == RelationTypes.THREAD:
if await self.store.event_includes_relation(relates_to): if await self.store.event_includes_relation(relation.parent_id):
raise SynapseError( raise SynapseError(
400, "Cannot start threads from an event with a relation" 400, "Cannot start threads from an event with a relation"
) )

View file

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections.abc
import logging import logging
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -28,7 +27,7 @@ import attr
from synapse.api.constants import RelationTypes from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events import EventBase from synapse.events import EventBase, relation_from_event
from synapse.storage.databases.main.relations import _RelatedEvent from synapse.storage.databases.main.relations import _RelatedEvent
from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -373,20 +372,21 @@ class RelationsHandler:
if event.is_state(): if event.is_state():
continue continue
relates_to = event.content.get("m.relates_to") relates_to = relation_from_event(event)
relation_type = None if relates_to:
if isinstance(relates_to, collections.abc.Mapping):
relation_type = relates_to.get("rel_type")
# An event which is a replacement (ie edit) or annotation (ie, # An event which is a replacement (ie edit) or annotation (ie,
# reaction) may not have any other event related to it. # reaction) may not have any other event related to it.
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): if relates_to.rel_type in (
RelationTypes.ANNOTATION,
RelationTypes.REPLACE,
):
continue continue
# Track the event's relation information for later.
relations_by_id[event.event_id] = relates_to.rel_type
# The event should get bundled aggregations. # The event should get bundled aggregations.
events_by_id[event.event_id] = event events_by_id[event.event_id] = event
# Track the event's relation information for later.
if isinstance(relation_type, str):
relations_by_id[event.event_id] = relation_type
# event ID -> bundled aggregation in non-serialized form. # event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {} results: Dict[str, BundledAggregations] = {}

View file

@ -21,7 +21,7 @@ from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership, RelationTypes from synapse.api.constants import EventTypes, Membership, RelationTypes
from synapse.event_auth import get_user_power_level from synapse.event_auth import get_user_power_level
from synapse.events import EventBase from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership from synapse.storage.databases.main.roommember import EventIdMembership
@ -78,8 +78,8 @@ def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
return False return False
# Exclude edits. # Exclude edits.
relates_to = event.content.get("m.relates_to", {}) relates_to = relation_from_event(event)
if relates_to.get("rel_type") == RelationTypes.REPLACE: if relates_to and relates_to.rel_type == RelationTypes.REPLACE:
return False return False
# Mark events that have a non-empty string body as unread. # Mark events that have a non-empty string body as unread.

View file

@ -36,8 +36,8 @@ from prometheus_client import Counter
import synapse.metrics import synapse.metrics
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase # noqa: F401 from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.snapshot import EventContext
from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
@ -1807,52 +1807,45 @@ class PersistEventsStore:
txn: The current database transaction. txn: The current database transaction.
event: The event which might have relations. event: The event which might have relations.
""" """
relation = event.content.get("m.relates_to") relation = relation_from_event(event)
if not relation: if not relation:
# No relations # No relation, nothing to do.
return return
# Relations must have a type and parent event ID.
rel_type = relation.get("rel_type")
if not isinstance(rel_type, str):
return
parent_id = relation.get("event_id")
if not isinstance(parent_id, str):
return
# Annotations have a key field.
aggregation_key = None
if rel_type == RelationTypes.ANNOTATION:
aggregation_key = relation.get("key")
self.db_pool.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="event_relations", table="event_relations",
values={ values={
"event_id": event.event_id, "event_id": event.event_id,
"relates_to_id": parent_id, "relates_to_id": relation.parent_id,
"relation_type": rel_type, "relation_type": relation.rel_type,
"aggregation_key": aggregation_key, "aggregation_key": relation.aggregation_key,
}, },
) )
txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,))
txn.call_after( txn.call_after(
self.store.get_aggregation_groups_for_event.invalidate, (parent_id,) self.store.get_relations_for_event.invalidate, (relation.parent_id,)
)
txn.call_after(
self.store.get_aggregation_groups_for_event.invalidate,
(relation.parent_id,),
) )
if rel_type == RelationTypes.REPLACE: if relation.rel_type == RelationTypes.REPLACE:
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) txn.call_after(
self.store.get_applicable_edit.invalidate, (relation.parent_id,)
)
if rel_type == RelationTypes.THREAD: if relation.rel_type == RelationTypes.THREAD:
txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) txn.call_after(
self.store.get_thread_summary.invalidate, (relation.parent_id,)
)
# It should be safe to only invalidate the cache if the user has not # It should be safe to only invalidate the cache if the user has not
# previously participated in the thread, but that's difficult (and # previously participated in the thread, but that's difficult (and
# potentially error-prone) so it is always invalidated. # potentially error-prone) so it is always invalidated.
txn.call_after( txn.call_after(
self.store.get_thread_participated.invalidate, self.store.get_thread_participated.invalidate,
(parent_id, event.sender), (relation.parent_id, event.sender),
) )
def _handle_insertion_event( def _handle_insertion_event(

View file

@ -656,12 +656,13 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self._check_unread_count(3) self._check_unread_count(3)
# Check that custom events with a body increase the unread counter. # Check that custom events with a body increase the unread counter.
self.helper.send_event( result = self.helper.send_event(
self.room_id, self.room_id,
"org.matrix.custom_type", "org.matrix.custom_type",
{"body": "hello"}, {"body": "hello"},
tok=self.tok2, tok=self.tok2,
) )
event_id = result["event_id"]
self._check_unread_count(4) self._check_unread_count(4)
# Check that edits don't increase the unread counter. # Check that edits don't increase the unread counter.
@ -671,7 +672,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
content={ content={
"body": "hello", "body": "hello",
"msgtype": "m.text", "msgtype": "m.text",
"m.relates_to": {"rel_type": RelationTypes.REPLACE}, "m.relates_to": {
"rel_type": RelationTypes.REPLACE,
"event_id": event_id,
},
}, },
tok=self.tok2, tok=self.tok2,
) )