mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-26 03:25:53 +03:00
Simplifications and comments in do_auth (#5227)
I was staring at this function trying to figure out wtf it was actually doing. This is (hopefully) a non-functional refactor which makes it a bit clearer.
This commit is contained in:
parent
1a94de60e8
commit
85d1e03b9d
3 changed files with 187 additions and 125 deletions
1
changelog.d/5227.misc
Normal file
1
changelog.d/5227.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Simplifications and comments in do_auth.
|
|
@ -2013,15 +2013,44 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
origin (str):
|
origin (str):
|
||||||
event (synapse.events.FrozenEvent):
|
event (synapse.events.EventBase):
|
||||||
context (synapse.events.snapshot.EventContext):
|
context (synapse.events.snapshot.EventContext):
|
||||||
auth_events (dict[(str, str)->str]):
|
auth_events (dict[(str, str)->synapse.events.EventBase]):
|
||||||
|
Map from (event_type, state_key) to event
|
||||||
|
|
||||||
|
What we expect the event's auth_events to be, based on the event's
|
||||||
|
position in the dag. I think? maybe??
|
||||||
|
|
||||||
|
Also NB that this function adds entries to it.
|
||||||
|
Returns:
|
||||||
|
defer.Deferred[None]
|
||||||
|
"""
|
||||||
|
room_version = yield self.store.get_room_version(event.room_id)
|
||||||
|
|
||||||
|
yield self._update_auth_events_and_context_for_auth(
|
||||||
|
origin, event, context, auth_events
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
self.auth.check(room_version, event, auth_events=auth_events)
|
||||||
|
except AuthError as e:
|
||||||
|
logger.warn("Failed auth resolution for %r because %s", event, e)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _update_auth_events_and_context_for_auth(
|
||||||
|
self, origin, event, context, auth_events
|
||||||
|
):
|
||||||
|
"""Helper for do_auth. See there for docs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
origin (str):
|
||||||
|
event (synapse.events.EventBase):
|
||||||
|
context (synapse.events.snapshot.EventContext):
|
||||||
|
auth_events (dict[(str, str)->synapse.events.EventBase]):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[None]
|
defer.Deferred[None]
|
||||||
"""
|
"""
|
||||||
# Check if we have all the auth events.
|
|
||||||
current_state = set(e.event_id for e in auth_events.values())
|
|
||||||
event_auth_events = set(event.auth_event_ids())
|
event_auth_events = set(event.auth_event_ids())
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
|
@ -2029,11 +2058,21 @@ class FederationHandler(BaseHandler):
|
||||||
else:
|
else:
|
||||||
event_key = None
|
event_key = None
|
||||||
|
|
||||||
if event_auth_events - current_state:
|
# if the event's auth_events refers to events which are not in our
|
||||||
|
# calculated auth_events, we need to fetch those events from somewhere.
|
||||||
|
#
|
||||||
|
# we start by fetching them from the store, and then try calling /event_auth/.
|
||||||
|
missing_auth = event_auth_events.difference(
|
||||||
|
e.event_id for e in auth_events.values()
|
||||||
|
)
|
||||||
|
|
||||||
|
if missing_auth:
|
||||||
# TODO: can we use store.have_seen_events here instead?
|
# TODO: can we use store.have_seen_events here instead?
|
||||||
have_events = yield self.store.get_seen_events_with_rejections(
|
have_events = yield self.store.get_seen_events_with_rejections(
|
||||||
event_auth_events - current_state
|
missing_auth
|
||||||
)
|
)
|
||||||
|
logger.debug("Got events %s from store", have_events)
|
||||||
|
missing_auth.difference_update(have_events.keys())
|
||||||
else:
|
else:
|
||||||
have_events = {}
|
have_events = {}
|
||||||
|
|
||||||
|
@ -2042,13 +2081,12 @@ class FederationHandler(BaseHandler):
|
||||||
for e in auth_events.values()
|
for e in auth_events.values()
|
||||||
})
|
})
|
||||||
|
|
||||||
seen_events = set(have_events.keys())
|
|
||||||
|
|
||||||
missing_auth = event_auth_events - seen_events - current_state
|
|
||||||
|
|
||||||
if missing_auth:
|
if missing_auth:
|
||||||
logger.info("Missing auth: %s", missing_auth)
|
|
||||||
# If we don't have all the auth events, we need to get them.
|
# If we don't have all the auth events, we need to get them.
|
||||||
|
logger.info(
|
||||||
|
"auth_events contains unknown events: %s",
|
||||||
|
missing_auth,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
remote_auth_chain = yield self.federation_client.get_event_auth(
|
remote_auth_chain = yield self.federation_client.get_event_auth(
|
||||||
origin, event.room_id, event.event_id
|
origin, event.room_id, event.event_id
|
||||||
|
@ -2089,145 +2127,168 @@ class FederationHandler(BaseHandler):
|
||||||
have_events = yield self.store.get_seen_events_with_rejections(
|
have_events = yield self.store.get_seen_events_with_rejections(
|
||||||
event.auth_event_ids()
|
event.auth_event_ids()
|
||||||
)
|
)
|
||||||
seen_events = set(have_events.keys())
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# FIXME:
|
# FIXME:
|
||||||
logger.exception("Failed to get auth chain")
|
logger.exception("Failed to get auth chain")
|
||||||
|
|
||||||
|
if event.internal_metadata.is_outlier():
|
||||||
|
logger.info("Skipping auth_event fetch for outlier")
|
||||||
|
return
|
||||||
|
|
||||||
# FIXME: Assumes we have and stored all the state for all the
|
# FIXME: Assumes we have and stored all the state for all the
|
||||||
# prev_events
|
# prev_events
|
||||||
current_state = set(e.event_id for e in auth_events.values())
|
different_auth = event_auth_events.difference(
|
||||||
different_auth = event_auth_events - current_state
|
e.event_id for e in auth_events.values()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not different_auth:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"auth_events refers to events which are not in our calculated auth "
|
||||||
|
"chain: %s",
|
||||||
|
different_auth,
|
||||||
|
)
|
||||||
|
|
||||||
room_version = yield self.store.get_room_version(event.room_id)
|
room_version = yield self.store.get_room_version(event.room_id)
|
||||||
|
|
||||||
if different_auth and not event.internal_metadata.is_outlier():
|
different_events = yield logcontext.make_deferred_yieldable(
|
||||||
# Do auth conflict res.
|
defer.gatherResults([
|
||||||
logger.info("Different auth: %s", different_auth)
|
logcontext.run_in_background(
|
||||||
|
self.store.get_event,
|
||||||
different_events = yield logcontext.make_deferred_yieldable(
|
d,
|
||||||
defer.gatherResults([
|
allow_none=True,
|
||||||
logcontext.run_in_background(
|
allow_rejected=False,
|
||||||
self.store.get_event,
|
|
||||||
d,
|
|
||||||
allow_none=True,
|
|
||||||
allow_rejected=False,
|
|
||||||
)
|
|
||||||
for d in different_auth
|
|
||||||
if d in have_events and not have_events[d]
|
|
||||||
], consumeErrors=True)
|
|
||||||
).addErrback(unwrapFirstError)
|
|
||||||
|
|
||||||
if different_events:
|
|
||||||
local_view = dict(auth_events)
|
|
||||||
remote_view = dict(auth_events)
|
|
||||||
remote_view.update({
|
|
||||||
(d.type, d.state_key): d for d in different_events if d
|
|
||||||
})
|
|
||||||
|
|
||||||
new_state = yield self.state_handler.resolve_events(
|
|
||||||
room_version,
|
|
||||||
[list(local_view.values()), list(remote_view.values())],
|
|
||||||
event
|
|
||||||
)
|
)
|
||||||
|
for d in different_auth
|
||||||
|
if d in have_events and not have_events[d]
|
||||||
|
], consumeErrors=True)
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
auth_events.update(new_state)
|
if different_events:
|
||||||
|
local_view = dict(auth_events)
|
||||||
|
remote_view = dict(auth_events)
|
||||||
|
remote_view.update({
|
||||||
|
(d.type, d.state_key): d for d in different_events if d
|
||||||
|
})
|
||||||
|
|
||||||
current_state = set(e.event_id for e in auth_events.values())
|
new_state = yield self.state_handler.resolve_events(
|
||||||
different_auth = event_auth_events - current_state
|
room_version,
|
||||||
|
[list(local_view.values()), list(remote_view.values())],
|
||||||
|
event
|
||||||
|
)
|
||||||
|
|
||||||
yield self._update_context_for_auth_events(
|
logger.info(
|
||||||
event, context, auth_events, event_key,
|
"After state res: updating auth_events with new state %s",
|
||||||
)
|
{
|
||||||
|
(d.type, d.state_key): d.event_id for d in new_state.values()
|
||||||
|
if auth_events.get((d.type, d.state_key)) != d
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
if different_auth and not event.internal_metadata.is_outlier():
|
auth_events.update(new_state)
|
||||||
logger.info("Different auth after resolution: %s", different_auth)
|
|
||||||
|
|
||||||
# Only do auth resolution if we have something new to say.
|
different_auth = event_auth_events.difference(
|
||||||
# We can't rove an auth failure.
|
e.event_id for e in auth_events.values()
|
||||||
do_resolution = False
|
)
|
||||||
|
|
||||||
provable = [
|
yield self._update_context_for_auth_events(
|
||||||
RejectedReason.NOT_ANCESTOR, RejectedReason.NOT_ANCESTOR,
|
event, context, auth_events, event_key,
|
||||||
]
|
)
|
||||||
|
|
||||||
for e_id in different_auth:
|
if not different_auth:
|
||||||
if e_id in have_events:
|
# we're done
|
||||||
if have_events[e_id] in provable:
|
return
|
||||||
do_resolution = True
|
|
||||||
break
|
|
||||||
|
|
||||||
if do_resolution:
|
logger.info(
|
||||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
"auth_events still refers to events which are not in the calculated auth "
|
||||||
# 1. Get what we think is the auth chain.
|
"chain after state resolution: %s",
|
||||||
auth_ids = yield self.auth.compute_auth_events(
|
different_auth,
|
||||||
event, prev_state_ids
|
)
|
||||||
)
|
|
||||||
local_auth_chain = yield self.store.get_auth_chain(
|
|
||||||
auth_ids, include_given=True
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
# Only do auth resolution if we have something new to say.
|
||||||
# 2. Get remote difference.
|
# We can't prove an auth failure.
|
||||||
result = yield self.federation_client.query_auth(
|
do_resolution = False
|
||||||
origin,
|
|
||||||
event.room_id,
|
|
||||||
event.event_id,
|
|
||||||
local_auth_chain,
|
|
||||||
)
|
|
||||||
|
|
||||||
seen_remotes = yield self.store.have_seen_events(
|
for e_id in different_auth:
|
||||||
[e.event_id for e in result["auth_chain"]]
|
if e_id in have_events:
|
||||||
)
|
if have_events[e_id] == RejectedReason.NOT_ANCESTOR:
|
||||||
|
do_resolution = True
|
||||||
|
break
|
||||||
|
|
||||||
# 3. Process any remote auth chain events we haven't seen.
|
if not do_resolution:
|
||||||
for ev in result["auth_chain"]:
|
logger.info(
|
||||||
if ev.event_id in seen_remotes:
|
"Skipping auth resolution due to lack of provable rejection reasons"
|
||||||
continue
|
)
|
||||||
|
return
|
||||||
|
|
||||||
if ev.event_id == event.event_id:
|
logger.info("Doing auth resolution")
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
auth_ids = ev.auth_event_ids()
|
|
||||||
auth = {
|
|
||||||
(e.type, e.state_key): e
|
|
||||||
for e in result["auth_chain"]
|
|
||||||
if e.event_id in auth_ids
|
|
||||||
or event.type == EventTypes.Create
|
|
||||||
}
|
|
||||||
ev.internal_metadata.outlier = True
|
|
||||||
|
|
||||||
logger.debug(
|
# 1. Get what we think is the auth chain.
|
||||||
"do_auth %s different_auth: %s",
|
auth_ids = yield self.auth.compute_auth_events(
|
||||||
event.event_id, e.event_id
|
event, prev_state_ids
|
||||||
)
|
)
|
||||||
|
local_auth_chain = yield self.store.get_auth_chain(
|
||||||
yield self._handle_new_event(
|
auth_ids, include_given=True
|
||||||
origin, ev, auth_events=auth
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if ev.event_id in event_auth_events:
|
|
||||||
auth_events[(ev.type, ev.state_key)] = ev
|
|
||||||
except AuthError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
# FIXME:
|
|
||||||
logger.exception("Failed to query auth chain")
|
|
||||||
|
|
||||||
# 4. Look at rejects and their proofs.
|
|
||||||
# TODO.
|
|
||||||
|
|
||||||
yield self._update_context_for_auth_events(
|
|
||||||
event, context, auth_events, event_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.auth.check(room_version, event, auth_events=auth_events)
|
# 2. Get remote difference.
|
||||||
except AuthError as e:
|
result = yield self.federation_client.query_auth(
|
||||||
logger.warn("Failed auth resolution for %r because %s", event, e)
|
origin,
|
||||||
raise e
|
event.room_id,
|
||||||
|
event.event_id,
|
||||||
|
local_auth_chain,
|
||||||
|
)
|
||||||
|
|
||||||
|
seen_remotes = yield self.store.have_seen_events(
|
||||||
|
[e.event_id for e in result["auth_chain"]]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Process any remote auth chain events we haven't seen.
|
||||||
|
for ev in result["auth_chain"]:
|
||||||
|
if ev.event_id in seen_remotes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ev.event_id == event.event_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
auth_ids = ev.auth_event_ids()
|
||||||
|
auth = {
|
||||||
|
(e.type, e.state_key): e
|
||||||
|
for e in result["auth_chain"]
|
||||||
|
if e.event_id in auth_ids
|
||||||
|
or event.type == EventTypes.Create
|
||||||
|
}
|
||||||
|
ev.internal_metadata.outlier = True
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"do_auth %s different_auth: %s",
|
||||||
|
event.event_id, e.event_id
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self._handle_new_event(
|
||||||
|
origin, ev, auth_events=auth
|
||||||
|
)
|
||||||
|
|
||||||
|
if ev.event_id in event_auth_events:
|
||||||
|
auth_events[(ev.type, ev.state_key)] = ev
|
||||||
|
except AuthError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# FIXME:
|
||||||
|
logger.exception("Failed to query auth chain")
|
||||||
|
|
||||||
|
# 4. Look at rejects and their proofs.
|
||||||
|
# TODO.
|
||||||
|
|
||||||
|
yield self._update_context_for_auth_events(
|
||||||
|
event, context, auth_events, event_key,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _update_context_for_auth_events(self, event, context, auth_events,
|
def _update_context_for_auth_events(self, event, context, auth_events,
|
||||||
|
|
|
@ -610,7 +610,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
return self.runInteraction("get_rejection_reasons", f)
|
return self.runInteraction("get_seen_events_with_rejections", f)
|
||||||
|
|
||||||
def _get_total_state_event_counts_txn(self, txn, room_id):
|
def _get_total_state_event_counts_txn(self, txn, room_id):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in a new issue