Supply auth_chain along with current state in '/state/', fetch auth events from a remote server if we are missing some of them

This commit is contained in:
Mark Haines 2014-12-18 18:47:13 +00:00
parent dbe77ec79a
commit 041ac476a5
4 changed files with 56 additions and 32 deletions

View file

@ -256,31 +256,35 @@ class ReplicationLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_state_for_context(self, destination, context, event_id=None): def get_state_for_context(self, destination, context, event_id):
"""Requests all of the `current` state PDUs for a given context from """Requests all of the `current` state PDUs for a given context from
a remote home server. a remote home server.
Args: Args:
destination (str): The remote homeserver to query for the state. destination (str): The remote homeserver to query for the state.
context (str): The context we're interested in. context (str): The context we're interested in.
event_id (str): The id of the event we want the state at.
Returns: Returns:
Deferred: Results in a list of PDUs. Deferred: Results in a list of PDUs.
""" """
transaction_data = yield self.transport_layer.get_context_state( result = yield self.transport_layer.get_context_state(
destination, destination,
context, context,
event_id=event_id, event_id=event_id,
) )
transaction = Transaction(**transaction_data)
pdus = [ pdus = [
self.event_from_pdu_json(p, outlier=True) self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
for p in transaction.pdus
] ]
defer.returnValue(pdus) auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in result.get("auth_chain", [])
]
defer.returnValue((pdus, auth_chain))
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -383,10 +387,16 @@ class ReplicationLayer(object):
context, context,
event_id, event_id,
) )
auth_chain = yield self.store.get_auth_chain(
[pdu.event_id for pdu in pdus]
)
else: else:
raise NotImplementedError("Specify an event") raise NotImplementedError("Specify an event")
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) defer.returnValue((200, {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
}))
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -573,6 +583,8 @@ class ReplicationLayer(object):
state = None state = None
auth_chain = []
# We need to make sure we have all the auth events. # We need to make sure we have all the auth events.
# for e_id, _ in pdu.auth_events: # for e_id, _ in pdu.auth_events:
# exists = yield self._get_persisted_pdu( # exists = yield self._get_persisted_pdu(
@ -645,7 +657,7 @@ class ReplicationLayer(object):
"_handle_new_pdu getting state for %s", "_handle_new_pdu getting state for %s",
pdu.room_id pdu.room_id
) )
state = yield self.get_state_for_context( state, auth_chain = yield self.get_state_for_context(
origin, pdu.room_id, pdu.event_id, origin, pdu.room_id, pdu.event_id,
) )
@ -655,6 +667,7 @@ class ReplicationLayer(object):
pdu, pdu,
backfilled=backfilled, backfilled=backfilled,
state=state, state=state,
auth_chain=auth_chain,
) )
else: else:
ret = None ret = None

View file

@ -95,7 +95,8 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, backfilled, state=None): def on_receive_pdu(self, origin, pdu, backfilled, state=None,
auth_chain=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to """ Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it through the StateHandler. do auth checks and put it through the StateHandler.
""" """
@ -150,8 +151,15 @@ class FederationHandler(BaseHandler):
if not is_in_room and not event.internal_metadata.outlier: if not is_in_room and not event.internal_metadata.outlier:
logger.debug("Got event for room we're not in.") logger.debug("Got event for room we're not in.")
replication_layer = self.replication_layer replication = self.replication_layer
auth_chain = yield replication_layer.get_event_auth(
if not state:
state, auth_chain = yield replication.get_state_for_context(
origin, context=event.room_id, event_id=event.event_id,
)
if not auth_chain:
auth_chain = yield replication.get_event_auth(
origin, origin,
context=event.room_id, context=event.room_id,
event_id=event.event_id, event_id=event.event_id,
@ -160,25 +168,18 @@ class FederationHandler(BaseHandler):
for e in auth_chain: for e in auth_chain:
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: try:
yield self._handle_new_event(e, fetch_missing=False) yield self._handle_new_event(e, fetch_auth_from=origin)
except: except:
logger.exception( logger.exception(
"Failed to handle auth event %s", "Failed to handle auth event %s",
e.event_id, e.event_id,
) )
if not state:
state = yield replication_layer.get_state_for_context(
origin,
context=event.room_id,
event_id=event.event_id,
)
# FIXME: Get auth chain for these state events
current_state = state current_state = state
if state: if state:
for e in state: for e in state:
logging.info("A :) %r", e)
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: try:
yield self._handle_new_event(e) yield self._handle_new_event(e)
@ -392,7 +393,7 @@ class FederationHandler(BaseHandler):
for e in auth_chain: for e in auth_chain:
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: try:
yield self._handle_new_event(e, fetch_missing=False) yield self._handle_new_event(e)
except: except:
logger.exception( logger.exception(
"Failed to handle auth event %s", "Failed to handle auth event %s",
@ -404,8 +405,7 @@ class FederationHandler(BaseHandler):
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: try:
yield self._handle_new_event( yield self._handle_new_event(
e, e, fetch_auth_from=target_host
fetch_missing=True
) )
except: except:
logger.exception( logger.exception(
@ -682,7 +682,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_event(self, event, state=None, backfilled=False, def _handle_new_event(self, event, state=None, backfilled=False,
current_state=None, fetch_missing=True): current_state=None, fetch_auth_from=None):
logger.debug( logger.debug(
"_handle_new_event: Before annotate: %s, sigs: %s", "_handle_new_event: Before annotate: %s, sigs: %s",
@ -703,11 +703,20 @@ class FederationHandler(BaseHandler):
known_ids = set( known_ids = set(
[s.event_id for s in context.auth_events.values()] [s.event_id for s in context.auth_events.values()]
) )
for e_id, _ in event.auth_events: for e_id, _ in event.auth_events:
if e_id not in known_ids: if e_id not in known_ids:
e = yield self.store.get_event( e = yield self.store.get_event(e_id, allow_none=True)
e_id, allow_none=True,
if not e and fetch_auth_from is not None:
# Grab the auth_chain over federation if we are missing
# auth events.
auth_chain = yield self.replication_layer.get_event_auth(
fetch_auth_from, event.event_id, event.room_id
) )
for auth_event in auth_chain:
yield self._handle_new_event(auth_event)
e = yield self.store.get_event(e_id, allow_none=True)
if not e: if not e:
# TODO: Do some conflict res to make sure that we're # TODO: Do some conflict res to make sure that we're

View file

@ -120,5 +120,5 @@ class Signal(object):
results = [] results = []
for deferred in deferreds: for deferred in deferreds:
result = yield deferred result = yield deferred
results.append(results) results.append(result)
defer.returnValue(results) defer.returnValue(results)

View file

@ -52,6 +52,7 @@ class FederationTestCase(unittest.TestCase):
"get_received_txn_response", "get_received_txn_response",
"set_received_txn_response", "set_received_txn_response",
"get_destination_retry_timings", "get_destination_retry_timings",
"get_auth_chain",
]) ])
self.mock_persistence.get_received_txn_response.return_value = ( self.mock_persistence.get_received_txn_response.return_value = (
defer.succeed(None) defer.succeed(None)
@ -59,6 +60,7 @@ class FederationTestCase(unittest.TestCase):
self.mock_persistence.get_destination_retry_timings.return_value = ( self.mock_persistence.get_destination_retry_timings.return_value = (
defer.succeed(DestinationsTable.EntryType("", 0, 0)) defer.succeed(DestinationsTable.EntryType("", 0, 0))
) )
self.mock_persistence.get_auth_chain.return_value = []
self.mock_config = Mock() self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()] self.mock_config.signing_key = [MockKey()]
self.clock = MockClock() self.clock = MockClock()