mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-29 15:39:00 +03:00
Add basic editing support
This commit is contained in:
parent
5c39d262c0
commit
d46aab3fa8
4 changed files with 167 additions and 15 deletions
|
@ -346,7 +346,7 @@ class EventClientSerializer(object):
|
||||||
defer.returnValue(event)
|
defer.returnValue(event)
|
||||||
|
|
||||||
event_id = event.event_id
|
event_id = event.event_id
|
||||||
event = serialize_event(event, time_now, **kwargs)
|
serialized_event = serialize_event(event, time_now, **kwargs)
|
||||||
|
|
||||||
# If MSC1849 is enabled then we need to look if thre are any relations
|
# If MSC1849 is enabled then we need to look if thre are any relations
|
||||||
# we need to bundle in with the event
|
# we need to bundle in with the event
|
||||||
|
@ -359,14 +359,36 @@ class EventClientSerializer(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
if annotations.chunk:
|
if annotations.chunk:
|
||||||
r = event["unsigned"].setdefault("m.relations", {})
|
r = serialized_event["unsigned"].setdefault("m.relations", {})
|
||||||
r[RelationTypes.ANNOTATION] = annotations.to_dict()
|
r[RelationTypes.ANNOTATION] = annotations.to_dict()
|
||||||
|
|
||||||
if references.chunk:
|
if references.chunk:
|
||||||
r = event["unsigned"].setdefault("m.relations", {})
|
r = serialized_event["unsigned"].setdefault("m.relations", {})
|
||||||
r[RelationTypes.REFERENCES] = references.to_dict()
|
r[RelationTypes.REFERENCES] = references.to_dict()
|
||||||
|
|
||||||
defer.returnValue(event)
|
edit = None
|
||||||
|
if event.type == EventTypes.Message:
|
||||||
|
edit = yield self.store.get_applicable_edit(
|
||||||
|
event.event_id, event.type, event.sender,
|
||||||
|
)
|
||||||
|
|
||||||
|
if edit:
|
||||||
|
# If there is an edit replace the content, preserving existing
|
||||||
|
# relations.
|
||||||
|
|
||||||
|
relations = event.content.get("m.relates_to")
|
||||||
|
serialized_event["content"] = edit.content.get("m.new_content", {})
|
||||||
|
if relations:
|
||||||
|
serialized_event["content"]["m.relates_to"] = relations
|
||||||
|
else:
|
||||||
|
serialized_event["content"].pop("m.relates_to", None)
|
||||||
|
|
||||||
|
r = serialized_event["unsigned"].setdefault("m.relations", {})
|
||||||
|
r[RelationTypes.REPLACES] = {
|
||||||
|
"event_id": edit.event_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
defer.returnValue(serialized_event)
|
||||||
|
|
||||||
def serialize_events(self, events, time_now, **kwargs):
|
def serialize_events(self, events, time_now, **kwargs):
|
||||||
"""Serializes multiple events.
|
"""Serializes multiple events.
|
||||||
|
|
|
@ -143,3 +143,4 @@ class SlavedEventStore(EventFederationWorkerStore,
|
||||||
if relates_to:
|
if relates_to:
|
||||||
self.get_relations_for_event.invalidate_many((relates_to,))
|
self.get_relations_for_event.invalidate_many((relates_to,))
|
||||||
self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
|
self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
|
||||||
|
self.get_applicable_edit.invalidate_many((relates_to,))
|
||||||
|
|
|
@ -17,11 +17,13 @@ import logging
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from synapse.api.constants import RelationTypes
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.constants import EventTypes, RelationTypes
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.stream import generate_pagination_where_clause
|
from synapse.storage.stream import generate_pagination_where_clause
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -312,6 +314,59 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||||
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
|
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(tree=True)
|
||||||
|
def get_applicable_edit(self, event_id, event_type, sender):
|
||||||
|
"""Get the most recent edit (if any) that has happened for the given
|
||||||
|
event.
|
||||||
|
|
||||||
|
Correctly handles checking whether edits were allowed to happen.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_id (str): The original event ID
|
||||||
|
event_type (str): The original event type
|
||||||
|
sender (str): The original event sender
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[EventBase|None]: Returns the most recent edit, if any.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# We only allow edits for `m.room.message` events that have the same sender
|
||||||
|
# and event type. We can't assert these things during regular event auth so
|
||||||
|
# we have to do the post hoc.
|
||||||
|
|
||||||
|
if event_type != EventTypes.Message:
|
||||||
|
return
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
SELECT event_id, origin_server_ts FROM events
|
||||||
|
INNER JOIN event_relations USING (event_id)
|
||||||
|
WHERE
|
||||||
|
relates_to_id = ?
|
||||||
|
AND relation_type = ?
|
||||||
|
AND type = ?
|
||||||
|
AND sender = ?
|
||||||
|
ORDER by origin_server_ts DESC, event_id DESC
|
||||||
|
LIMIT 1
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_applicable_edit_txn(txn):
|
||||||
|
txn.execute(
|
||||||
|
sql, (event_id, RelationTypes.REPLACES, event_type, sender)
|
||||||
|
)
|
||||||
|
row = txn.fetchone()
|
||||||
|
if row:
|
||||||
|
return row[0]
|
||||||
|
|
||||||
|
edit_id = yield self.runInteraction(
|
||||||
|
"get_applicable_edit", _get_applicable_edit_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
if not edit_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
edit_event = yield self.get_event(edit_id, allow_none=True)
|
||||||
|
defer.returnValue(edit_event)
|
||||||
|
|
||||||
|
|
||||||
class RelationsStore(RelationsWorkerStore):
|
class RelationsStore(RelationsWorkerStore):
|
||||||
def _handle_event_relations(self, txn, event):
|
def _handle_event_relations(self, txn, event):
|
||||||
|
@ -357,3 +412,4 @@ class RelationsStore(RelationsWorkerStore):
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
|
self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
|
||||||
)
|
)
|
||||||
|
txn.call_after(self.get_applicable_edit.invalidate_many, (parent_id,))
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
|
import json
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
@ -102,11 +103,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||||
# relation event we sent above.
|
# relation event we sent above.
|
||||||
self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
|
self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
|
||||||
self.assert_dict(
|
self.assert_dict(
|
||||||
{
|
{"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"},
|
||||||
"event_id": annotation_id,
|
|
||||||
"sender": self.user_id,
|
|
||||||
"type": "m.reaction",
|
|
||||||
},
|
|
||||||
channel.json_body["chunk"][0],
|
channel.json_body["chunk"][0],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -330,8 +327,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEquals(200, channel.code, channel.json_body)
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
|
|
||||||
self.maxDiff = None
|
|
||||||
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
channel.json_body["unsigned"].get("m.relations"),
|
channel.json_body["unsigned"].get("m.relations"),
|
||||||
{
|
{
|
||||||
|
@ -347,7 +342,84 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _send_relation(self, relation_type, event_type, key=None):
|
def test_edit(self):
|
||||||
|
"""Test that a simple edit works.
|
||||||
|
"""
|
||||||
|
|
||||||
|
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
|
||||||
|
channel = self._send_relation(
|
||||||
|
RelationTypes.REPLACES,
|
||||||
|
"m.room.message",
|
||||||
|
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
|
||||||
|
)
|
||||||
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
|
|
||||||
|
edit_event_id = channel.json_body["event_id"]
|
||||||
|
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", "/rooms/%s/event/%s" % (self.room, self.parent_id)
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
|
|
||||||
|
self.assertEquals(channel.json_body["content"], new_body)
|
||||||
|
|
||||||
|
self.assertEquals(
|
||||||
|
channel.json_body["unsigned"].get("m.relations"),
|
||||||
|
{RelationTypes.REPLACES: {"event_id": edit_event_id}},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_multi_edit(self):
|
||||||
|
"""Test that multiple edits, including attempts by people who
|
||||||
|
shouldn't be allowed, are correctly handled.
|
||||||
|
"""
|
||||||
|
|
||||||
|
channel = self._send_relation(
|
||||||
|
RelationTypes.REPLACES,
|
||||||
|
"m.room.message",
|
||||||
|
content={
|
||||||
|
"msgtype": "m.text",
|
||||||
|
"body": "Wibble",
|
||||||
|
"m.new_content": {"msgtype": "m.text", "body": "First edit"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
|
|
||||||
|
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
|
||||||
|
channel = self._send_relation(
|
||||||
|
RelationTypes.REPLACES,
|
||||||
|
"m.room.message",
|
||||||
|
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
|
||||||
|
)
|
||||||
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
|
|
||||||
|
edit_event_id = channel.json_body["event_id"]
|
||||||
|
|
||||||
|
channel = self._send_relation(
|
||||||
|
RelationTypes.REPLACES,
|
||||||
|
"m.room.message.WRONG_TYPE",
|
||||||
|
content={
|
||||||
|
"msgtype": "m.text",
|
||||||
|
"body": "Wibble",
|
||||||
|
"m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
|
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", "/rooms/%s/event/%s" % (self.room, self.parent_id)
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEquals(200, channel.code, channel.json_body)
|
||||||
|
|
||||||
|
self.assertEquals(channel.json_body["content"], new_body)
|
||||||
|
|
||||||
|
self.assertEquals(
|
||||||
|
channel.json_body["unsigned"].get("m.relations"),
|
||||||
|
{RelationTypes.REPLACES: {"event_id": edit_event_id}},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _send_relation(self, relation_type, event_type, key=None, content={}):
|
||||||
"""Helper function to send a relation pointing at `self.parent_id`
|
"""Helper function to send a relation pointing at `self.parent_id`
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -355,6 +427,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||||
event_type (str): The type of the event to create
|
event_type (str): The type of the event to create
|
||||||
key (str|None): The aggregation key used for m.annotation relation
|
key (str|None): The aggregation key used for m.annotation relation
|
||||||
type.
|
type.
|
||||||
|
content(dict|None): The content of the created event.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
FakeChannel
|
FakeChannel
|
||||||
|
@ -367,7 +440,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||||
"POST",
|
"POST",
|
||||||
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
|
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
|
||||||
% (self.room, self.parent_id, relation_type, event_type, query),
|
% (self.room, self.parent_id, relation_type, event_type, query),
|
||||||
b"{}",
|
json.dumps(content).encode("utf-8"),
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
return channel
|
return channel
|
||||||
|
|
Loading…
Reference in a new issue