Implemented thread support for backfills

This commit is contained in:
Erik Johnston 2018-11-13 14:49:49 +00:00
parent c67953748d
commit 08395c7f89
3 changed files with 64 additions and 12 deletions

View file

@@ -691,7 +691,12 @@ class FederationHandler(BaseHandler):
             # Don't bother processing events we already have.
             seen_events = yield self.store.have_events_in_timeline(
-                set(e.event_id for e in events)
+                set(
+                    itertools.chain.from_iterable(
+                        itertools.chain([e.event_id], e.prev_event_ids(),)
+                        for e in events
+                    )
+                )
             )

             events = [e for e in events if e.event_id not in seen_events]
@@ -706,7 +711,7 @@ class FederationHandler(BaseHandler):
         edges = [
             ev.event_id
             for ev in events
-            if set(ev.prev_event_ids()) - event_ids
+            if set(ev.prev_event_ids()) - event_ids - seen_events
         ]

         logger.info(
@@ -740,18 +745,26 @@ class FederationHandler(BaseHandler):
         })
         missing_auth = required_auth - set(auth_events)
         failed_to_fetch = set()
+        not_in_db = set()

         # Try and fetch any missing auth events from both DB and remote servers.
         # We repeatedly do this until we stop finding new auth events.
         while missing_auth - failed_to_fetch:
             logger.info("Missing auth for backfill: %r", missing_auth)
-            ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
-            auth_events.update(ret_events)

-            required_auth.update(
-                a_id for event in ret_events.values() for a_id in event.auth_event_ids()
-            )
-            missing_auth = required_auth - set(auth_events)
+            to_fetch_from_db = missing_auth - failed_to_fetch
+            while to_fetch_from_db - not_in_db:
+                ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
+                auth_events.update(ret_events)
+
+                required_auth.update(
+                    a_id
+                    for event in ret_events.values()
+                    for a_id in event.auth_event_ids()
+                )
+                missing_auth = required_auth - set(auth_events)
+
+                to_fetch_from_db = required_auth - set(auth_events) - not_in_db

             if missing_auth - failed_to_fetch:
                 logger.info(
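Note: the new inner loop keeps querying the local database until it stops discovering new auth events, with not_in_db meant to track IDs the database cannot supply; remote fetching only happens for what is still missing afterwards. A rough standalone sketch of that fixed-point pattern follows. It assumes (this bookkeeping is not shown in the hunk above) that IDs the DB fails to return are added to not_in_db, and every name other than those in the diff is illustrative:

def fetch_auth_from_db(get_events, required_auth, auth_events):
    # Repeatedly pull missing auth events out of the local DB until no new
    # ones turn up.
    not_in_db = set()
    to_fetch_from_db = required_auth - set(auth_events)

    while to_fetch_from_db - not_in_db:
        ret_events = get_events(to_fetch_from_db - not_in_db)
        auth_events.update(ret_events)

        # Assumption: record IDs the DB did not return so we stop asking for
        # them (the hunk above does not show where not_in_db is populated).
        not_in_db.update(to_fetch_from_db - set(ret_events))

        # Newly fetched auth events may themselves reference further auth events.
        required_auth.update(
            a_id for event in ret_events.values() for a_id in event.auth_event_ids()
        )
        to_fetch_from_db = required_auth - set(auth_events) - not_in_db

    # Whatever is left has to come from remote servers (or be given up on).
    return required_auth - set(auth_events)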
@@ -820,6 +833,25 @@ class FederationHandler(BaseHandler):
         events.sort(key=lambda e: e.depth)

+        event_id_to_thread = {}
+        event_to_parents = {}
+        for event in reversed(events):
+            threads = yield self.store.get_threads_for_backfill_event(event.event_id)
+
+            parents = event_to_parents.get(event.event_id, [])
+            for p in parents:
+                t = event_id_to_thread.get(p)
+                if t is not None:
+                    threads.append(t)
+
+            if threads:
+                thread_id = min(threads)
+            else:
+                thread_id = 0
+
+            event_id_to_thread[event.event_id] = thread_id
+
+            for c in event.prev_event_ids():
+                event_to_parents.setdefault(c, set()).add(event.event_id)
+
         for event in events:
             if event in events_to_state:
                 continue
@@ -829,6 +861,7 @@ class FederationHandler(BaseHandler):
             # TODO: We can probably do something more clever here.
             yield self._handle_new_event(
                 dest, event, backfilled=True,
+                thread_id=event_id_to_thread[event.event_id],
             )

         defer.returnValue(events)
@@ -838,12 +871,13 @@ class FederationHandler(BaseHandler):
         """Checks the database to see if we should backfill before paginating,
         and if so do.
         """
+        logger.info("Backfilling")
         extremities = yield self.store.get_oldest_events_with_depth_in_room(
             room_id
         )

         if not extremities:
-            logger.debug("Not backfilling as no extremeties found.")
+            logger.info("Not backfilling as no extremeties found.")
             return

         # Check if we reached a point where we should start backfilling.
@@ -858,7 +892,7 @@ class FederationHandler(BaseHandler):
         extremities = dict(sorted_extremeties_tuple[:5])

         if current_depth > max_depth:
-            logger.debug(
+            logger.info(
                 "Not backfilling as we don't need to. %d < %d",
                 max_depth, current_depth,
             )
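The block added to backfill() above is what picks a thread for each backfilled event: events are walked deepest-first, the thread IDs of any known successors are gathered (from the store, and from successors already handled in the same batch), and the smallest one is used, defaulting to thread 0. A minimal synchronous sketch of that pass, with the store call stubbed out as a plain callable; only the names taken from the diff are real:

def assign_backfill_threads(events, get_threads_for_backfill_event):
    # `events` is sorted by ascending depth, as in backfill() above.
    event_id_to_thread = {}
    event_to_successors = {}  # event_id -> ids of events naming it as a prev_event

    # Deepest-first, so an event's successors are assigned before the event itself.
    for event in reversed(events):
        # Threads of already-persisted events pointing at this one ...
        threads = list(get_threads_for_backfill_event(event.event_id))

        # ... plus threads of successors seen earlier in this batch.
        for successor_id in event_to_successors.get(event.event_id, ()):
            t = event_id_to_thread.get(successor_id)
            if t is not None:
                threads.append(t)

        # Join the lowest-numbered thread of any successor, else thread 0.
        event_id_to_thread[event.event_id] = min(threads) if threads else 0

        for prev_id in event.prev_event_ids():
            event_to_successors.setdefault(prev_id, set()).add(event.event_id)

    return event_id_to_thread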

View file

@@ -28,7 +28,6 @@ from synapse.events.utils import (
     format_event_raw,
     serialize_event,
 )
-from synapse.events import FrozenEvent
 from synapse.handlers.presence import format_user_presence_state
 from synapse.handlers.sync import SyncConfig
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
@@ -246,7 +245,8 @@ class SyncRestServlet(RestServlet):
         }

     @staticmethod
-    def encode_joined(rooms, time_now, token_id, event_fields, event_formatter, exclude_threaded):
+    def encode_joined(rooms, time_now, token_id, event_fields, event_formatter,
+                      exclude_threaded):
         """
         Encode the joined rooms in a sync result

View file

@@ -34,6 +34,24 @@ logger = logging.getLogger(__name__)

 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
                                  SQLBaseStore):
+    def get_threads_for_backfill_event(self, event_id):
+        def _get_thread_for_backfill_event_txn(txn):
+            sql = """
+                SELECT thread_id
+                FROM event_edges
+                INNER JOIN events USING (event_id)
+                WHERE prev_event_id = ?
+            """
+
+            txn.execute(sql, (event_id,))
+            return [thread_id for thread_id, in txn]
+
+        return self.runInteraction(
+            "get_thread_for_backfill_event",
+            _get_thread_for_backfill_event_txn,
+        )
+
     def get_auth_chain(self, event_ids, include_given=False):
         """Get auth events for given event_ids. The events *must* be state events.