recursively fetch redactions

This commit is contained in:
Richard van der Hoff 2019-07-24 16:44:10 +01:00
parent e6a6c4fbab
commit 448bcfd0f9

View file

@ -17,7 +17,6 @@ from __future__ import division
import itertools import itertools
import logging import logging
import operator
from collections import namedtuple from collections import namedtuple
from canonicaljson import json from canonicaljson import json
@ -30,12 +29,7 @@ from synapse.api.room_versions import EventFormatVersions
from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.logging.context import ( from synapse.logging.context import LoggingContext, PreserveLoggingContext
LoggingContext,
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
)
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util import batch_iter from synapse.util import batch_iter
@ -468,39 +462,49 @@ class EventsWorkerStore(SQLBaseStore):
Returns: Returns:
Deferred[Dict[str, _EventCacheEntry]]: Deferred[Dict[str, _EventCacheEntry]]:
map from event id to result. map from event id to result. May return extra events which
weren't asked for.
""" """
if not event_ids: fetched_events = {}
return {} events_to_fetch = event_ids
row_map = yield self._enqueue_events(event_ids) while events_to_fetch:
row_map = yield self._enqueue_events(events_to_fetch)
rows = (row_map.get(event_id) for event_id in event_ids) # we need to recursively fetch any redactions of those events
redaction_ids = set()
for event_id in events_to_fetch:
row = row_map.get(event_id)
fetched_events[event_id] = row
if row:
redaction_ids.update(row["redactions"])
# filter out absent rows events_to_fetch = redaction_ids.difference(fetched_events.keys())
rows = filter(operator.truth, rows) if events_to_fetch:
logger.debug("Also fetching redaction events %s", events_to_fetch)
if not allow_rejected: result_map = {}
rows = (r for r in rows if r["rejected_reason"] is None) for event_id, row in fetched_events.items():
if not row:
continue
assert row["event_id"] == event_id
res = yield make_deferred_yieldable( rejected_reason = row["rejected_reason"]
defer.gatherResults(
[ if not allow_rejected and rejected_reason:
run_in_background( continue
self._get_event_from_row,
cache_entry = yield self._get_event_from_row(
row["internal_metadata"], row["internal_metadata"],
row["json"], row["json"],
row["redactions"], row["redactions"],
rejected_reason=row["rejected_reason"], rejected_reason=row["rejected_reason"],
format_version=row["format_version"], format_version=row["format_version"],
) )
for row in rows
],
consumeErrors=True,
)
)
return {e.event.event_id: e for e in res if e} result_map[event_id] = cache_entry
return result_map
@defer.inlineCallbacks @defer.inlineCallbacks
def _enqueue_events(self, events): def _enqueue_events(self, events):