Store and fetch thread IDs

This commit is contained in:
Erik Johnston 2018-11-12 15:44:22 +00:00
parent dc59ad5334
commit dfa830e61a
5 changed files with 46 additions and 14 deletions

View file

@ -74,6 +74,7 @@ class EventContext(object):
"delta_ids",
"prev_state_events",
"app_service",
"thread_id",
"_current_state_ids",
"_prev_state_ids",
"_prev_state_id",
@ -89,8 +90,9 @@ class EventContext(object):
@staticmethod
def with_state(state_group, current_state_ids, prev_state_ids,
prev_group=None, delta_ids=None):
thread_id, prev_group=None, delta_ids=None):
context = EventContext()
context.thread_id = thread_id
# The current state including the current event
context._current_state_ids = current_state_ids
@ -141,7 +143,8 @@ class EventContext(object):
"prev_group": self.prev_group,
"delta_ids": _encode_state_dict(self.delta_ids),
"prev_state_events": self.prev_state_events,
"app_service_id": self.app_service.id if self.app_service else None
"app_service_id": self.app_service.id if self.app_service else None,
"thread_id": self.thread_id,
})
@staticmethod
@ -158,6 +161,8 @@ class EventContext(object):
"""
context = EventContext()
context.thread_id = input["thread_input"]
# We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids.
context._prev_state_id = input["prev_state_id"]

View file

@ -18,6 +18,7 @@
import itertools
import logging
import random
import six
from six import iteritems, itervalues
@ -135,7 +136,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def on_receive_pdu(
self, origin, pdu, sent_to_us_directly=False,
self, origin, pdu, sent_to_us_directly=False, thread_id=None,
):
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
@ -222,6 +223,10 @@ class FederationHandler(BaseHandler):
state = None
auth_chain = []
if thread_id is None:
# FIXME: Pick something better?
thread_id = random.randint(0, 999999999)
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
@ -259,7 +264,8 @@ class FederationHandler(BaseHandler):
)
yield self._get_missing_events_for_pdu(
origin, pdu, prevs, min_depth
origin, pdu, prevs, min_depth,
thread_id=thread_id,
)
# Update the set of things we've seen after trying to
@ -414,15 +420,24 @@ class FederationHandler(BaseHandler):
affected=event_id,
)
now = self.clock.time_msec()
if now - pdu.origin_server_ts > 2 * 60 * 1000:
pass
else:
thread_id = 0
logger.info("Thread ID %r", thread_id)
yield self._process_received_pdu(
origin,
pdu,
state=state,
auth_chain=auth_chain,
thread_id=thread_id,
)
@defer.inlineCallbacks
def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth, thread_id):
"""
Args:
origin (str): Origin of the pdu. Will be called to get the missing events
@ -529,6 +544,7 @@ class FederationHandler(BaseHandler):
origin,
ev,
sent_to_us_directly=False,
thread_id=thread_id,
)
except FederationError as e:
if e.code == 403:
@ -540,7 +556,7 @@ class FederationHandler(BaseHandler):
raise
@defer.inlineCallbacks
def _process_received_pdu(self, origin, event, state, auth_chain):
def _process_received_pdu(self, origin, event, state, auth_chain, thread_id):
""" Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
"""
@ -592,6 +608,7 @@ class FederationHandler(BaseHandler):
origin,
event,
state=state,
thread_id=thread_id,
)
except AuthError as e:
raise FederationError(
@ -1557,11 +1574,12 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def _handle_new_event(self, origin, event, state=None, auth_events=None,
backfilled=False):
backfilled=False, thread_id=0):
context = yield self._prep_event(
origin, event,
state=state,
auth_events=auth_events,
thread_id=thread_id,
)
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
@ -1720,7 +1738,7 @@ class FederationHandler(BaseHandler):
)
@defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, auth_events=None):
def _prep_event(self, origin, event, state=None, auth_events=None, thread_id=0):
"""
Args:
@ -1733,7 +1751,7 @@ class FederationHandler(BaseHandler):
Deferred, which resolves to synapse.events.snapshot.EventContext
"""
context = yield self.state_handler.compute_event_context(
event, old_state=state,
event, old_state=state, thread_id=thread_id,
)
if not auth_events:

View file

@ -178,7 +178,7 @@ class StateHandler(object):
defer.returnValue(joined_hosts)
@defer.inlineCallbacks
def compute_event_context(self, event, old_state=None):
def compute_event_context(self, event, old_state=None, thread_id=0):
"""Build an EventContext structure for the event.
This works out what the current state should be for the event, and
@ -215,6 +215,7 @@ class StateHandler(object):
# We don't store state for outliers, so we don't generate a state
# group for it.
context = EventContext.with_state(
thread_id=0, # outlier, don't care
state_group=None,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
@ -251,6 +252,7 @@ class StateHandler(object):
)
context = EventContext.with_state(
thread_id=thread_id,
state_group=state_group,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
@ -319,6 +321,7 @@ class StateHandler(object):
state_group = entry.state_group
context = EventContext.with_state(
thread_id=thread_id,
state_group=state_group,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,

View file

@ -1282,8 +1282,9 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
"url" in event.content
and isinstance(event.content["url"], text_type)
),
"thread_id": ctx.thread_id,
}
for event, _ in events_and_contexts
for event, ctx in events_and_contexts
],
)

View file

@ -352,7 +352,7 @@ class EventsWorkerStore(SQLBaseStore):
run_in_background(
self._get_event_from_row,
row["internal_metadata"], row["json"], row["redacts"],
rejected_reason=row["rejects"],
rejected_reason=row["rejects"], thread_id=row["thread_id"],
)
for row in rows
],
@ -378,8 +378,10 @@ class EventsWorkerStore(SQLBaseStore):
" e.internal_metadata,"
" e.json,"
" r.redacts as redacts,"
" rej.event_id as rejects "
" rej.event_id as rejects, "
" ev.thread_id as thread_id"
" FROM event_json as e"
" INNER JOIN events as ev USING (event_id)"
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)"
@ -392,10 +394,11 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def _get_event_from_row(self, internal_metadata, js, redacted,
rejected_reason=None):
thread_id, rejected_reason=None):
with Measure(self._clock, "_get_event_from_row"):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
internal_metadata["thread_id"] = thread_id
if rejected_reason:
rejected_reason = yield self._simple_select_one_onecol(
@ -411,6 +414,8 @@ class EventsWorkerStore(SQLBaseStore):
rejected_reason=rejected_reason,
)
original_ev.unsigned["thread_id"] = thread_id
redacted_event = None
if redacted:
redacted_event = prune_event(original_ev)