Use fast path

This commit is contained in:
Erik Johnston 2024-09-15 11:25:23 +01:00
parent 35d797a9c4
commit fb751d3914
2 changed files with 112 additions and 19 deletions

View file

@ -14,11 +14,12 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, cast
from typing import TYPE_CHECKING, Collection, Dict, List, Mapping, Optional, Set, cast
import attr
from synapse.api.errors import SlidingSyncUnknownPosition
from synapse.events import EventBase
from synapse.logging.opentracing import log_kv
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
@ -451,6 +452,38 @@ class SlidingSyncStore(SQLBaseStore):
room_configs=room_configs,
)
async def get_visibility_for_events(
self, room_id: str, events: Collection[EventBase]
) -> Mapping[str, Optional[str]]:
def get_visibility_for_events_txn(
txn: LoggingTransaction,
) -> Mapping[str, Optional[str]]:
sql = """
SELECT visibility FROM history_visibility_ranges
WHERE start_range <= ? AND (? < end_range OR end_range IS NULL)
AND room_id = ?
"""
results = {}
for event in events:
txn.execute(
sql,
(
event.internal_metadata.stream_ordering,
event.internal_metadata.stream_ordering,
room_id,
),
)
row = txn.fetchone()
if row is not None:
results[event.event_id] = row[0]
return results
return await self.db_pool.runInteraction(
"get_visibility_for_events", get_visibility_for_events_txn
)
@attr.s(auto_attribs=True, frozen=True)
class PerConnectionStateDB:

View file

@ -105,6 +105,9 @@ async def filter_events_for_client(
The filtered events. The `unsigned` data is annotated with the membership state
of `user_id` at each event.
"""
if not events:
return []
# Filter out events that have been soft failed so that we don't relay them
# to clients.
events_before_filtering = events
@ -117,13 +120,38 @@ async def filter_events_for_client(
[event.event_id for event in events],
)
types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id))
types = (
_HISTORY_VIS_KEY,
(EventTypes.Member, user_id),
)
room_id = events[0].room_id
assert all(event.room_id == room_id for event in events)
visibilities: Dict[str, str] = {}
memberships: Dict[str, Optional[EventBase]] = {}
events_to_fetch = {e.event_id for e in events if not e.internal_metadata.outlier}
if not is_peeking:
fetched_visibilities = await storage.main.get_visibility_for_events(
room_id, [e for e in events if not e.internal_metadata.outlier]
)
for event_id, visibility in fetched_visibilities.items():
if visibility in (
HistoryVisibility.SHARED,
HistoryVisibility.WORLD_READABLE,
):
events_to_fetch.discard(event_id)
visibilities[event_id] = visibility
# we exclude outliers at this point, and then handle them separately later
event_id_to_state = await storage.state.get_state_for_events(
frozenset(e.event_id for e in events if not e.internal_metadata.outlier),
state_filter=StateFilter.from_types(types),
)
if events_to_fetch:
event_id_to_state = await storage.state.get_state_for_events(
events_to_fetch,
state_filter=StateFilter.from_types(types),
)
for event_id, state in event_id_to_state.items():
visibilities[event_id] = get_effective_room_visibility_from_state(state)
memberships[event_id] = state.get((EventTypes.Member, user_id))
# Get the users who are ignored by the requesting user.
ignore_list = await storage.main.ignored_users(user_id)
@ -140,8 +168,8 @@ async def filter_events_for_client(
] = await storage.main.get_retention_policy_for_room(room_id)
def allowed(event: EventBase) -> Optional[EventBase]:
state_after_event = event_id_to_state.get(event.event_id)
filtered = _check_client_allowed_to_see_event(
# state_after_event = event_id_to_state.get(event.event_id)
filtered = _check_client_allowed_to_see_event_with_state(
user_id=user_id,
event=event,
clock=storage.main.clock,
@ -149,9 +177,10 @@ async def filter_events_for_client(
sender_ignored=event.sender in ignore_list,
always_include_ids=always_include_ids,
retention_policy=retention_policies[event.room_id],
state=state_after_event,
is_peeking=is_peeking,
sender_erased=erased_senders.get(event.sender, False),
visibility=visibilities[event.event_id],
membership_event=memberships.get(event.event_id),
)
if filtered is None:
return None
@ -165,11 +194,9 @@ async def filter_events_for_client(
user_membership_event: Optional[EventBase]
if event.type == EventTypes.Member and event.state_key == user_id:
user_membership_event = event
elif state_after_event is not None:
user_membership_event = state_after_event.get((EventTypes.Member, user_id))
else:
# unreachable!
raise Exception("Missing state for event that is not user's own membership")
# TODO: Actually get the proper membership
user_membership_event = memberships.get(event_id)
user_membership = (
user_membership_event.membership
@ -353,6 +380,41 @@ def _check_client_allowed_to_see_event(
the original event if they can see it as normal.
"""
visibility = HistoryVisibility.SHARED
if state is not None:
visibility = get_effective_room_visibility_from_state(state)
membership_event = state.get((EventTypes.Member, user_id)) if state else None
return _check_client_allowed_to_see_event_with_state(
user_id,
event,
clock,
filter_send_to_client,
is_peeking,
always_include_ids,
sender_ignored,
retention_policy,
sender_erased,
visibility=visibility,
membership_event=membership_event,
)
def _check_client_allowed_to_see_event_with_state(
user_id: str,
event: EventBase,
clock: Clock,
filter_send_to_client: bool,
is_peeking: bool,
always_include_ids: FrozenSet[str],
sender_ignored: bool,
retention_policy: RetentionPolicy,
sender_erased: bool,
visibility: str,
membership_event: Optional[EventBase],
) -> Optional[EventBase]:
# Only run some checks if these events aren't about to be sent to clients. This is
# because, if this is not the case, we're probably only checking if the users can
# see events in the room at that point in the DAG, and that shouldn't be decided
@ -390,12 +452,6 @@ def _check_client_allowed_to_see_event(
)
return None
if state is None:
raise Exception("Missing state for non-outlier event")
# get the room_visibility at the time of the event.
visibility = get_effective_room_visibility_from_state(state)
# Check if the room has lax history visibility, allowing us to skip
# membership checks.
#
@ -408,6 +464,10 @@ def _check_client_allowed_to_see_event(
):
return event
if membership_event:
state = {(EventTypes.Member, user_id): membership_event}
else:
state = {}
membership_result = _check_membership(user_id, event, visibility, state, is_peeking)
if not membership_result.allowed:
filtered_event_logger.debug(