mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-25 19:15:51 +03:00
Move methods involving event authentication to EventAuthHandler. (#10268)
Instead of mixing them with user authentication methods.
This commit is contained in:
parent
0aab50c772
commit
8d609435c0
11 changed files with 112 additions and 106 deletions
1
changelog.d/10268.misc
Normal file
1
changelog.d/10268.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Move event authentication methods from `Auth` to `EventAuthHandler`.
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Optional, Tuple
|
||||||
|
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
from netaddr import IPAddress
|
from netaddr import IPAddress
|
||||||
|
@ -28,10 +28,8 @@ from synapse.api.errors import (
|
||||||
InvalidClientTokenError,
|
InvalidClientTokenError,
|
||||||
MissingClientTokenError,
|
MissingClientTokenError,
|
||||||
)
|
)
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.builder import EventBuilder
|
|
||||||
from synapse.http import get_request_user_agent
|
from synapse.http import get_request_user_agent
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging import opentracing as opentracing
|
from synapse.logging import opentracing as opentracing
|
||||||
|
@ -39,7 +37,6 @@ from synapse.storage.databases.main.registration import TokenLookupResult
|
||||||
from synapse.types import Requester, StateMap, UserID, create_requester
|
from synapse.types import Requester, StateMap, UserID, create_requester
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
|
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
|
||||||
from synapse.util.metrics import Measure
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -47,15 +44,6 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
AuthEventTypes = (
|
|
||||||
EventTypes.Create,
|
|
||||||
EventTypes.Member,
|
|
||||||
EventTypes.PowerLevels,
|
|
||||||
EventTypes.JoinRules,
|
|
||||||
EventTypes.RoomHistoryVisibility,
|
|
||||||
EventTypes.ThirdPartyInvite,
|
|
||||||
)
|
|
||||||
|
|
||||||
# guests always get this device id.
|
# guests always get this device id.
|
||||||
GUEST_DEVICE_ID = "guest_device"
|
GUEST_DEVICE_ID = "guest_device"
|
||||||
|
|
||||||
|
@ -66,9 +54,7 @@ class _InvalidMacaroonException(Exception):
|
||||||
|
|
||||||
class Auth:
|
class Auth:
|
||||||
"""
|
"""
|
||||||
FIXME: This class contains a mix of functions for authenticating users
|
This class contains functions for authenticating users of our client-server API.
|
||||||
of our client-server API and authenticating events added to room graphs.
|
|
||||||
The latter should be moved to synapse.handlers.event_auth.EventAuthHandler.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
@ -90,18 +76,6 @@ class Auth:
|
||||||
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
||||||
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
|
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
|
||||||
|
|
||||||
async def check_from_context(
|
|
||||||
self, room_version: str, event, context, do_sig_check=True
|
|
||||||
) -> None:
|
|
||||||
auth_event_ids = event.auth_event_ids()
|
|
||||||
auth_events_by_id = await self.store.get_events(auth_event_ids)
|
|
||||||
auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
|
|
||||||
|
|
||||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
|
||||||
event_auth.check(
|
|
||||||
room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
|
|
||||||
)
|
|
||||||
|
|
||||||
async def check_user_in_room(
|
async def check_user_in_room(
|
||||||
self,
|
self,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
|
@ -152,13 +126,6 @@ class Auth:
|
||||||
|
|
||||||
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
|
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
|
||||||
|
|
||||||
async def check_host_in_room(self, room_id: str, host: str) -> bool:
|
|
||||||
with Measure(self.clock, "check_host_in_room"):
|
|
||||||
return await self.store.is_host_joined(room_id, host)
|
|
||||||
|
|
||||||
def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]:
|
|
||||||
return event_auth.get_public_keys(invite_event)
|
|
||||||
|
|
||||||
async def get_user_by_req(
|
async def get_user_by_req(
|
||||||
self,
|
self,
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
|
@ -489,44 +456,6 @@ class Auth:
|
||||||
"""
|
"""
|
||||||
return await self.store.is_server_admin(user)
|
return await self.store.is_server_admin(user)
|
||||||
|
|
||||||
def compute_auth_events(
|
|
||||||
self,
|
|
||||||
event: Union[EventBase, EventBuilder],
|
|
||||||
current_state_ids: StateMap[str],
|
|
||||||
for_verification: bool = False,
|
|
||||||
) -> List[str]:
|
|
||||||
"""Given an event and current state return the list of event IDs used
|
|
||||||
to auth an event.
|
|
||||||
|
|
||||||
If `for_verification` is False then only return auth events that
|
|
||||||
should be added to the event's `auth_events`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of event IDs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if event.type == EventTypes.Create:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Currently we ignore the `for_verification` flag even though there are
|
|
||||||
# some situations where we can drop particular auth events when adding
|
|
||||||
# to the event's `auth_events` (e.g. joins pointing to previous joins
|
|
||||||
# when room is publicly joinable). Dropping event IDs has the
|
|
||||||
# advantage that the auth chain for the room grows slower, but we use
|
|
||||||
# the auth chain in state resolution v2 to order events, which means
|
|
||||||
# care must be taken if dropping events to ensure that it doesn't
|
|
||||||
# introduce undesirable "state reset" behaviour.
|
|
||||||
#
|
|
||||||
# All of which sounds a bit tricky so we don't bother for now.
|
|
||||||
|
|
||||||
auth_ids = []
|
|
||||||
for etype, state_key in event_auth.auth_types_for_event(event):
|
|
||||||
auth_ev_id = current_state_ids.get((etype, state_key))
|
|
||||||
if auth_ev_id:
|
|
||||||
auth_ids.append(auth_ev_id)
|
|
||||||
|
|
||||||
return auth_ids
|
|
||||||
|
|
||||||
async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
|
async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
|
||||||
"""Determine whether the user is allowed to edit the room's entry in the
|
"""Determine whether the user is allowed to edit the room's entry in the
|
||||||
published room list.
|
published room list.
|
||||||
|
|
|
@ -34,7 +34,7 @@ from synapse.util import Clock
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.api.auth import Auth
|
from synapse.handlers.event_auth import EventAuthHandler
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -66,7 +66,7 @@ class EventBuilder:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_state: StateHandler
|
_state: StateHandler
|
||||||
_auth: "Auth"
|
_event_auth_handler: "EventAuthHandler"
|
||||||
_store: DataStore
|
_store: DataStore
|
||||||
_clock: Clock
|
_clock: Clock
|
||||||
_hostname: str
|
_hostname: str
|
||||||
|
@ -125,7 +125,9 @@ class EventBuilder:
|
||||||
state_ids = await self._state.get_current_state_ids(
|
state_ids = await self._state.get_current_state_ids(
|
||||||
self.room_id, prev_event_ids
|
self.room_id, prev_event_ids
|
||||||
)
|
)
|
||||||
auth_event_ids = self._auth.compute_auth_events(self, state_ids)
|
auth_event_ids = self._event_auth_handler.compute_auth_events(
|
||||||
|
self, state_ids
|
||||||
|
)
|
||||||
|
|
||||||
format_version = self.room_version.event_format
|
format_version = self.room_version.event_format
|
||||||
if format_version == EventFormatVersions.V1:
|
if format_version == EventFormatVersions.V1:
|
||||||
|
@ -193,7 +195,7 @@ class EventBuilderFactory:
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.auth = hs.get_auth()
|
self._event_auth_handler = hs.get_event_auth_handler()
|
||||||
|
|
||||||
def new(self, room_version: str, key_values: dict) -> EventBuilder:
|
def new(self, room_version: str, key_values: dict) -> EventBuilder:
|
||||||
"""Generate an event builder appropriate for the given room version
|
"""Generate an event builder appropriate for the given room version
|
||||||
|
@ -229,7 +231,7 @@ class EventBuilderFactory:
|
||||||
return EventBuilder(
|
return EventBuilder(
|
||||||
store=self.store,
|
store=self.store,
|
||||||
state=self.state,
|
state=self.state,
|
||||||
auth=self.auth,
|
event_auth_handler=self._event_auth_handler,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
hostname=self.hostname,
|
hostname=self.hostname,
|
||||||
signing_key=self.signing_key,
|
signing_key=self.signing_key,
|
||||||
|
|
|
@ -108,9 +108,9 @@ class FederationServer(FederationBase):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
self.auth = hs.get_auth()
|
|
||||||
self.handler = hs.get_federation_handler()
|
self.handler = hs.get_federation_handler()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
self._event_auth_handler = hs.get_event_auth_handler()
|
||||||
|
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
|
|
||||||
|
@ -420,7 +420,7 @@ class FederationServer(FederationBase):
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
await self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
|
|
||||||
in_room = await self.auth.check_host_in_room(room_id, origin)
|
in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
|
||||||
if not in_room:
|
if not in_room:
|
||||||
raise AuthError(403, "Host not in room.")
|
raise AuthError(403, "Host not in room.")
|
||||||
|
|
||||||
|
@ -453,7 +453,7 @@ class FederationServer(FederationBase):
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
await self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
|
|
||||||
in_room = await self.auth.check_host_in_room(room_id, origin)
|
in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
|
||||||
if not in_room:
|
if not in_room:
|
||||||
raise AuthError(403, "Host not in room.")
|
raise AuthError(403, "Host not in room.")
|
||||||
|
|
||||||
|
|
|
@ -11,8 +11,9 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import TYPE_CHECKING, Collection, Optional
|
from typing import TYPE_CHECKING, Collection, List, Optional, Union
|
||||||
|
|
||||||
|
from synapse import event_auth
|
||||||
from synapse.api.constants import (
|
from synapse.api.constants import (
|
||||||
EventTypes,
|
EventTypes,
|
||||||
JoinRules,
|
JoinRules,
|
||||||
|
@ -20,9 +21,11 @@ from synapse.api.constants import (
|
||||||
RestrictedJoinRuleTypes,
|
RestrictedJoinRuleTypes,
|
||||||
)
|
)
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.api.room_versions import RoomVersion
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
|
from synapse.events.builder import EventBuilder
|
||||||
from synapse.types import StateMap
|
from synapse.types import StateMap
|
||||||
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -34,8 +37,63 @@ class EventAuthHandler:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
self._clock = hs.get_clock()
|
||||||
self._store = hs.get_datastore()
|
self._store = hs.get_datastore()
|
||||||
|
|
||||||
|
async def check_from_context(
|
||||||
|
self, room_version: str, event, context, do_sig_check=True
|
||||||
|
) -> None:
|
||||||
|
auth_event_ids = event.auth_event_ids()
|
||||||
|
auth_events_by_id = await self._store.get_events(auth_event_ids)
|
||||||
|
auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
|
||||||
|
|
||||||
|
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||||
|
event_auth.check(
|
||||||
|
room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_auth_events(
|
||||||
|
self,
|
||||||
|
event: Union[EventBase, EventBuilder],
|
||||||
|
current_state_ids: StateMap[str],
|
||||||
|
for_verification: bool = False,
|
||||||
|
) -> List[str]:
|
||||||
|
"""Given an event and current state return the list of event IDs used
|
||||||
|
to auth an event.
|
||||||
|
|
||||||
|
If `for_verification` is False then only return auth events that
|
||||||
|
should be added to the event's `auth_events`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of event IDs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if event.type == EventTypes.Create:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Currently we ignore the `for_verification` flag even though there are
|
||||||
|
# some situations where we can drop particular auth events when adding
|
||||||
|
# to the event's `auth_events` (e.g. joins pointing to previous joins
|
||||||
|
# when room is publicly joinable). Dropping event IDs has the
|
||||||
|
# advantage that the auth chain for the room grows slower, but we use
|
||||||
|
# the auth chain in state resolution v2 to order events, which means
|
||||||
|
# care must be taken if dropping events to ensure that it doesn't
|
||||||
|
# introduce undesirable "state reset" behaviour.
|
||||||
|
#
|
||||||
|
# All of which sounds a bit tricky so we don't bother for now.
|
||||||
|
|
||||||
|
auth_ids = []
|
||||||
|
for etype, state_key in event_auth.auth_types_for_event(event):
|
||||||
|
auth_ev_id = current_state_ids.get((etype, state_key))
|
||||||
|
if auth_ev_id:
|
||||||
|
auth_ids.append(auth_ev_id)
|
||||||
|
|
||||||
|
return auth_ids
|
||||||
|
|
||||||
|
async def check_host_in_room(self, room_id: str, host: str) -> bool:
|
||||||
|
with Measure(self._clock, "check_host_in_room"):
|
||||||
|
return await self._store.is_host_joined(room_id, host)
|
||||||
|
|
||||||
async def check_restricted_join_rules(
|
async def check_restricted_join_rules(
|
||||||
self,
|
self,
|
||||||
state_ids: StateMap[str],
|
state_ids: StateMap[str],
|
||||||
|
|
|
@ -250,7 +250,9 @@ class FederationHandler(BaseHandler):
|
||||||
#
|
#
|
||||||
# Note that if we were never in the room then we would have already
|
# Note that if we were never in the room then we would have already
|
||||||
# dropped the event, since we wouldn't know the room version.
|
# dropped the event, since we wouldn't know the room version.
|
||||||
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
|
is_in_room = await self._event_auth_handler.check_host_in_room(
|
||||||
|
room_id, self.server_name
|
||||||
|
)
|
||||||
if not is_in_room:
|
if not is_in_room:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Ignoring PDU from %s as we're not in the room",
|
"Ignoring PDU from %s as we're not in the room",
|
||||||
|
@ -1674,7 +1676,9 @@ class FederationHandler(BaseHandler):
|
||||||
room_version = await self.store.get_room_version_id(room_id)
|
room_version = await self.store.get_room_version_id(room_id)
|
||||||
|
|
||||||
# now check that we are *still* in the room
|
# now check that we are *still* in the room
|
||||||
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
|
is_in_room = await self._event_auth_handler.check_host_in_room(
|
||||||
|
room_id, self.server_name
|
||||||
|
)
|
||||||
if not is_in_room:
|
if not is_in_room:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Got /make_join request for room %s we are no longer in",
|
"Got /make_join request for room %s we are no longer in",
|
||||||
|
@ -1705,7 +1709,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||||
# when we get the event back in `on_send_join_request`
|
# when we get the event back in `on_send_join_request`
|
||||||
await self.auth.check_from_context(
|
await self._event_auth_handler.check_from_context(
|
||||||
room_version, event, context, do_sig_check=False
|
room_version, event, context, do_sig_check=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1877,7 +1881,7 @@ class FederationHandler(BaseHandler):
|
||||||
try:
|
try:
|
||||||
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||||
# when we get the event back in `on_send_leave_request`
|
# when we get the event back in `on_send_leave_request`
|
||||||
await self.auth.check_from_context(
|
await self._event_auth_handler.check_from_context(
|
||||||
room_version, event, context, do_sig_check=False
|
room_version, event, context, do_sig_check=False
|
||||||
)
|
)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
|
@ -1939,7 +1943,7 @@ class FederationHandler(BaseHandler):
|
||||||
try:
|
try:
|
||||||
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||||
# when we get the event back in `on_send_knock_request`
|
# when we get the event back in `on_send_knock_request`
|
||||||
await self.auth.check_from_context(
|
await self._event_auth_handler.check_from_context(
|
||||||
room_version, event, context, do_sig_check=False
|
room_version, event, context, do_sig_check=False
|
||||||
)
|
)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
|
@ -2111,7 +2115,7 @@ class FederationHandler(BaseHandler):
|
||||||
async def on_backfill_request(
|
async def on_backfill_request(
|
||||||
self, origin: str, room_id: str, pdu_list: List[str], limit: int
|
self, origin: str, room_id: str, pdu_list: List[str], limit: int
|
||||||
) -> List[EventBase]:
|
) -> List[EventBase]:
|
||||||
in_room = await self.auth.check_host_in_room(room_id, origin)
|
in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
|
||||||
if not in_room:
|
if not in_room:
|
||||||
raise AuthError(403, "Host not in room.")
|
raise AuthError(403, "Host not in room.")
|
||||||
|
|
||||||
|
@ -2146,7 +2150,9 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
if event:
|
if event:
|
||||||
in_room = await self.auth.check_host_in_room(event.room_id, origin)
|
in_room = await self._event_auth_handler.check_host_in_room(
|
||||||
|
event.room_id, origin
|
||||||
|
)
|
||||||
if not in_room:
|
if not in_room:
|
||||||
raise AuthError(403, "Host not in room.")
|
raise AuthError(403, "Host not in room.")
|
||||||
|
|
||||||
|
@ -2499,7 +2505,7 @@ class FederationHandler(BaseHandler):
|
||||||
latest_events: List[str],
|
latest_events: List[str],
|
||||||
limit: int,
|
limit: int,
|
||||||
) -> List[EventBase]:
|
) -> List[EventBase]:
|
||||||
in_room = await self.auth.check_host_in_room(room_id, origin)
|
in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
|
||||||
if not in_room:
|
if not in_room:
|
||||||
raise AuthError(403, "Host not in room.")
|
raise AuthError(403, "Host not in room.")
|
||||||
|
|
||||||
|
@ -2562,7 +2568,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
if not auth_events:
|
if not auth_events:
|
||||||
prev_state_ids = await context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids()
|
||||||
auth_events_ids = self.auth.compute_auth_events(
|
auth_events_ids = self._event_auth_handler.compute_auth_events(
|
||||||
event, prev_state_ids, for_verification=True
|
event, prev_state_ids, for_verification=True
|
||||||
)
|
)
|
||||||
auth_events_x = await self.store.get_events(auth_events_ids)
|
auth_events_x = await self.store.get_events(auth_events_ids)
|
||||||
|
@ -2991,7 +2997,7 @@ class FederationHandler(BaseHandler):
|
||||||
"state_key": target_user_id,
|
"state_key": target_user_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
if await self.auth.check_host_in_room(room_id, self.hs.hostname):
|
if await self._event_auth_handler.check_host_in_room(room_id, self.hs.hostname):
|
||||||
room_version = await self.store.get_room_version_id(room_id)
|
room_version = await self.store.get_room_version_id(room_id)
|
||||||
builder = self.event_builder_factory.new(room_version, event_dict)
|
builder = self.event_builder_factory.new(room_version, event_dict)
|
||||||
|
|
||||||
|
@ -3011,7 +3017,9 @@ class FederationHandler(BaseHandler):
|
||||||
event.internal_metadata.send_on_behalf_of = self.hs.hostname
|
event.internal_metadata.send_on_behalf_of = self.hs.hostname
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.auth.check_from_context(room_version, event, context)
|
await self._event_auth_handler.check_from_context(
|
||||||
|
room_version, event, context
|
||||||
|
)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
logger.warning("Denying new third party invite %r because %s", event, e)
|
logger.warning("Denying new third party invite %r because %s", event, e)
|
||||||
raise e
|
raise e
|
||||||
|
@ -3054,7 +3062,9 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.auth.check_from_context(room_version, event, context)
|
await self._event_auth_handler.check_from_context(
|
||||||
|
room_version, event, context
|
||||||
|
)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
logger.warning("Denying third party invite %r because %s", event, e)
|
logger.warning("Denying third party invite %r because %s", event, e)
|
||||||
raise e
|
raise e
|
||||||
|
@ -3142,7 +3152,7 @@ class FederationHandler(BaseHandler):
|
||||||
last_exception = None # type: Optional[Exception]
|
last_exception = None # type: Optional[Exception]
|
||||||
|
|
||||||
# for each public key in the 3pid invite event
|
# for each public key in the 3pid invite event
|
||||||
for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
|
for public_key_object in event_auth.get_public_keys(invite_event):
|
||||||
try:
|
try:
|
||||||
# for each sig on the third_party_invite block of the actual invite
|
# for each sig on the third_party_invite block of the actual invite
|
||||||
for server, signature_block in signed["signatures"].items():
|
for server, signature_block in signed["signatures"].items():
|
||||||
|
|
|
@ -385,6 +385,7 @@ class EventCreationHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
self._event_auth_handler = hs.get_event_auth_handler()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.storage = hs.get_storage()
|
self.storage = hs.get_storage()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
@ -597,7 +598,7 @@ class EventCreationHandler:
|
||||||
(e.type, e.state_key): e.event_id for e in auth_events
|
(e.type, e.state_key): e.event_id for e in auth_events
|
||||||
}
|
}
|
||||||
# Actually strip down and use the necessary auth events
|
# Actually strip down and use the necessary auth events
|
||||||
auth_event_ids = self.auth.compute_auth_events(
|
auth_event_ids = self._event_auth_handler.compute_auth_events(
|
||||||
event=temp_event,
|
event=temp_event,
|
||||||
current_state_ids=auth_event_state_map,
|
current_state_ids=auth_event_state_map,
|
||||||
for_verification=False,
|
for_verification=False,
|
||||||
|
@ -1056,7 +1057,9 @@ class EventCreationHandler:
|
||||||
assert event.content["membership"] == Membership.LEAVE
|
assert event.content["membership"] == Membership.LEAVE
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
await self.auth.check_from_context(room_version, event, context)
|
await self._event_auth_handler.check_from_context(
|
||||||
|
room_version, event, context
|
||||||
|
)
|
||||||
except AuthError as err:
|
except AuthError as err:
|
||||||
logger.warning("Denying new event %r because %s", event, err)
|
logger.warning("Denying new event %r because %s", event, err)
|
||||||
raise err
|
raise err
|
||||||
|
@ -1381,7 +1384,7 @@ class EventCreationHandler:
|
||||||
raise AuthError(403, "Redacting server ACL events is not permitted")
|
raise AuthError(403, "Redacting server ACL events is not permitted")
|
||||||
|
|
||||||
prev_state_ids = await context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids()
|
||||||
auth_events_ids = self.auth.compute_auth_events(
|
auth_events_ids = self._event_auth_handler.compute_auth_events(
|
||||||
event, prev_state_ids, for_verification=True
|
event, prev_state_ids, for_verification=True
|
||||||
)
|
)
|
||||||
auth_events_map = await self.store.get_events(auth_events_ids)
|
auth_events_map = await self.store.get_events(auth_events_ids)
|
||||||
|
|
|
@ -83,6 +83,7 @@ class RoomCreationHandler(BaseHandler):
|
||||||
self.spam_checker = hs.get_spam_checker()
|
self.spam_checker = hs.get_spam_checker()
|
||||||
self.event_creation_handler = hs.get_event_creation_handler()
|
self.event_creation_handler = hs.get_event_creation_handler()
|
||||||
self.room_member_handler = hs.get_room_member_handler()
|
self.room_member_handler = hs.get_room_member_handler()
|
||||||
|
self._event_auth_handler = hs.get_event_auth_handler()
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
|
|
||||||
# Room state based off defined presets
|
# Room state based off defined presets
|
||||||
|
@ -226,7 +227,7 @@ class RoomCreationHandler(BaseHandler):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
old_room_version = await self.store.get_room_version_id(old_room_id)
|
old_room_version = await self.store.get_room_version_id(old_room_id)
|
||||||
await self.auth.check_from_context(
|
await self._event_auth_handler.check_from_context(
|
||||||
old_room_version, tombstone_event, tombstone_context
|
old_room_version, tombstone_event, tombstone_context
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -472,7 +472,7 @@ class SpaceSummaryHandler:
|
||||||
# If this is a request over federation, check if the host is in the room or
|
# If this is a request over federation, check if the host is in the room or
|
||||||
# is in one of the spaces specified via the join rules.
|
# is in one of the spaces specified via the join rules.
|
||||||
elif origin:
|
elif origin:
|
||||||
if await self._auth.check_host_in_room(room_id, origin):
|
if await self._event_auth_handler.check_host_in_room(room_id, origin):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Alternately, if the host has a user in any of the spaces specified
|
# Alternately, if the host has a user in any of the spaces specified
|
||||||
|
@ -485,7 +485,9 @@ class SpaceSummaryHandler:
|
||||||
await self._event_auth_handler.get_rooms_that_allow_join(state_ids)
|
await self._event_auth_handler.get_rooms_that_allow_join(state_ids)
|
||||||
)
|
)
|
||||||
for space_id in allowed_rooms:
|
for space_id in allowed_rooms:
|
||||||
if await self._auth.check_host_in_room(space_id, origin):
|
if await self._event_auth_handler.check_host_in_room(
|
||||||
|
space_id, origin
|
||||||
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# otherwise, check if the room is peekable
|
# otherwise, check if the room is peekable
|
||||||
|
|
|
@ -104,7 +104,7 @@ class BulkPushRuleEvaluator:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self._event_auth_handler = hs.get_event_auth_handler()
|
||||||
|
|
||||||
# Used by `RulesForRoom` to ensure only one thing mutates the cache at a
|
# Used by `RulesForRoom` to ensure only one thing mutates the cache at a
|
||||||
# time. Keyed off room_id.
|
# time. Keyed off room_id.
|
||||||
|
@ -172,7 +172,7 @@ class BulkPushRuleEvaluator:
|
||||||
# not having a power level event is an extreme edge case
|
# not having a power level event is an extreme edge case
|
||||||
auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
|
auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
|
||||||
else:
|
else:
|
||||||
auth_events_ids = self.auth.compute_auth_events(
|
auth_events_ids = self._event_auth_handler.compute_auth_events(
|
||||||
event, prev_state_ids, for_verification=False
|
event, prev_state_ids, for_verification=False
|
||||||
)
|
)
|
||||||
auth_events_dict = await self.store.get_events(auth_events_ids)
|
auth_events_dict = await self.store.get_events(auth_events_ids)
|
||||||
|
|
|
@ -734,7 +734,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.auth = hs.get_auth()
|
self._event_auth_handler = hs.get_event_auth_handler()
|
||||||
|
|
||||||
# We don't actually check signatures in tests, so lets just create a
|
# We don't actually check signatures in tests, so lets just create a
|
||||||
# random key to use.
|
# random key to use.
|
||||||
|
@ -846,7 +846,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
builder = EventBuilder(
|
builder = EventBuilder(
|
||||||
state=self.state,
|
state=self.state,
|
||||||
auth=self.auth,
|
event_auth_handler=self._event_auth_handler,
|
||||||
store=self.store,
|
store=self.store,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
hostname=hostname,
|
hostname=hostname,
|
||||||
|
|
Loading…
Reference in a new issue