mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-18 00:43:30 +03:00
Use fast path
This commit is contained in:
parent
35d797a9c4
commit
fb751d3914
2 changed files with 112 additions and 19 deletions
|
@ -14,11 +14,12 @@
|
|||
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, cast
|
||||
from typing import TYPE_CHECKING, Collection, Dict, List, Mapping, Optional, Set, cast
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.errors import SlidingSyncUnknownPosition
|
||||
from synapse.events import EventBase
|
||||
from synapse.logging.opentracing import log_kv
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
|
@ -451,6 +452,38 @@ class SlidingSyncStore(SQLBaseStore):
|
|||
room_configs=room_configs,
|
||||
)
|
||||
|
||||
async def get_visibility_for_events(
|
||||
self, room_id: str, events: Collection[EventBase]
|
||||
) -> Mapping[str, Optional[str]]:
|
||||
def get_visibility_for_events_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Mapping[str, Optional[str]]:
|
||||
sql = """
|
||||
SELECT visibility FROM history_visibility_ranges
|
||||
WHERE start_range <= ? AND (? < end_range OR end_range IS NULL)
|
||||
AND room_id = ?
|
||||
"""
|
||||
|
||||
results = {}
|
||||
for event in events:
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
event.internal_metadata.stream_ordering,
|
||||
event.internal_metadata.stream_ordering,
|
||||
room_id,
|
||||
),
|
||||
)
|
||||
row = txn.fetchone()
|
||||
if row is not None:
|
||||
results[event.event_id] = row[0]
|
||||
|
||||
return results
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_visibility_for_events", get_visibility_for_events_txn
|
||||
)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True, frozen=True)
|
||||
class PerConnectionStateDB:
|
||||
|
|
|
@ -105,6 +105,9 @@ async def filter_events_for_client(
|
|||
The filtered events. The `unsigned` data is annotated with the membership state
|
||||
of `user_id` at each event.
|
||||
"""
|
||||
if not events:
|
||||
return []
|
||||
|
||||
# Filter out events that have been soft failed so that we don't relay them
|
||||
# to clients.
|
||||
events_before_filtering = events
|
||||
|
@ -117,13 +120,38 @@ async def filter_events_for_client(
|
|||
[event.event_id for event in events],
|
||||
)
|
||||
|
||||
types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id))
|
||||
types = (
|
||||
_HISTORY_VIS_KEY,
|
||||
(EventTypes.Member, user_id),
|
||||
)
|
||||
|
||||
room_id = events[0].room_id
|
||||
assert all(event.room_id == room_id for event in events)
|
||||
|
||||
visibilities: Dict[str, str] = {}
|
||||
memberships: Dict[str, Optional[EventBase]] = {}
|
||||
events_to_fetch = {e.event_id for e in events if not e.internal_metadata.outlier}
|
||||
if not is_peeking:
|
||||
fetched_visibilities = await storage.main.get_visibility_for_events(
|
||||
room_id, [e for e in events if not e.internal_metadata.outlier]
|
||||
)
|
||||
for event_id, visibility in fetched_visibilities.items():
|
||||
if visibility in (
|
||||
HistoryVisibility.SHARED,
|
||||
HistoryVisibility.WORLD_READABLE,
|
||||
):
|
||||
events_to_fetch.discard(event_id)
|
||||
visibilities[event_id] = visibility
|
||||
|
||||
# we exclude outliers at this point, and then handle them separately later
|
||||
event_id_to_state = await storage.state.get_state_for_events(
|
||||
frozenset(e.event_id for e in events if not e.internal_metadata.outlier),
|
||||
state_filter=StateFilter.from_types(types),
|
||||
)
|
||||
if events_to_fetch:
|
||||
event_id_to_state = await storage.state.get_state_for_events(
|
||||
events_to_fetch,
|
||||
state_filter=StateFilter.from_types(types),
|
||||
)
|
||||
for event_id, state in event_id_to_state.items():
|
||||
visibilities[event_id] = get_effective_room_visibility_from_state(state)
|
||||
memberships[event_id] = state.get((EventTypes.Member, user_id))
|
||||
|
||||
# Get the users who are ignored by the requesting user.
|
||||
ignore_list = await storage.main.ignored_users(user_id)
|
||||
|
@ -140,8 +168,8 @@ async def filter_events_for_client(
|
|||
] = await storage.main.get_retention_policy_for_room(room_id)
|
||||
|
||||
def allowed(event: EventBase) -> Optional[EventBase]:
|
||||
state_after_event = event_id_to_state.get(event.event_id)
|
||||
filtered = _check_client_allowed_to_see_event(
|
||||
# state_after_event = event_id_to_state.get(event.event_id)
|
||||
filtered = _check_client_allowed_to_see_event_with_state(
|
||||
user_id=user_id,
|
||||
event=event,
|
||||
clock=storage.main.clock,
|
||||
|
@ -149,9 +177,10 @@ async def filter_events_for_client(
|
|||
sender_ignored=event.sender in ignore_list,
|
||||
always_include_ids=always_include_ids,
|
||||
retention_policy=retention_policies[event.room_id],
|
||||
state=state_after_event,
|
||||
is_peeking=is_peeking,
|
||||
sender_erased=erased_senders.get(event.sender, False),
|
||||
visibility=visibilities[event.event_id],
|
||||
membership_event=memberships.get(event.event_id),
|
||||
)
|
||||
if filtered is None:
|
||||
return None
|
||||
|
@ -165,11 +194,9 @@ async def filter_events_for_client(
|
|||
user_membership_event: Optional[EventBase]
|
||||
if event.type == EventTypes.Member and event.state_key == user_id:
|
||||
user_membership_event = event
|
||||
elif state_after_event is not None:
|
||||
user_membership_event = state_after_event.get((EventTypes.Member, user_id))
|
||||
else:
|
||||
# unreachable!
|
||||
raise Exception("Missing state for event that is not user's own membership")
|
||||
# TODO: Actually get the proper membership
|
||||
user_membership_event = memberships.get(event_id)
|
||||
|
||||
user_membership = (
|
||||
user_membership_event.membership
|
||||
|
@ -353,6 +380,41 @@ def _check_client_allowed_to_see_event(
|
|||
|
||||
the original event if they can see it as normal.
|
||||
"""
|
||||
|
||||
visibility = HistoryVisibility.SHARED
|
||||
|
||||
if state is not None:
|
||||
visibility = get_effective_room_visibility_from_state(state)
|
||||
membership_event = state.get((EventTypes.Member, user_id)) if state else None
|
||||
|
||||
return _check_client_allowed_to_see_event_with_state(
|
||||
user_id,
|
||||
event,
|
||||
clock,
|
||||
filter_send_to_client,
|
||||
is_peeking,
|
||||
always_include_ids,
|
||||
sender_ignored,
|
||||
retention_policy,
|
||||
sender_erased,
|
||||
visibility=visibility,
|
||||
membership_event=membership_event,
|
||||
)
|
||||
|
||||
|
||||
def _check_client_allowed_to_see_event_with_state(
|
||||
user_id: str,
|
||||
event: EventBase,
|
||||
clock: Clock,
|
||||
filter_send_to_client: bool,
|
||||
is_peeking: bool,
|
||||
always_include_ids: FrozenSet[str],
|
||||
sender_ignored: bool,
|
||||
retention_policy: RetentionPolicy,
|
||||
sender_erased: bool,
|
||||
visibility: str,
|
||||
membership_event: Optional[EventBase],
|
||||
) -> Optional[EventBase]:
|
||||
# Only run some checks if these events aren't about to be sent to clients. This is
|
||||
# because, if this is not the case, we're probably only checking if the users can
|
||||
# see events in the room at that point in the DAG, and that shouldn't be decided
|
||||
|
@ -390,12 +452,6 @@ def _check_client_allowed_to_see_event(
|
|||
)
|
||||
return None
|
||||
|
||||
if state is None:
|
||||
raise Exception("Missing state for non-outlier event")
|
||||
|
||||
# get the room_visibility at the time of the event.
|
||||
visibility = get_effective_room_visibility_from_state(state)
|
||||
|
||||
# Check if the room has lax history visibility, allowing us to skip
|
||||
# membership checks.
|
||||
#
|
||||
|
@ -408,6 +464,10 @@ def _check_client_allowed_to_see_event(
|
|||
):
|
||||
return event
|
||||
|
||||
if membership_event:
|
||||
state = {(EventTypes.Member, user_id): membership_event}
|
||||
else:
|
||||
state = {}
|
||||
membership_result = _check_membership(user_id, event, visibility, state, is_peeking)
|
||||
if not membership_result.allowed:
|
||||
filtered_event_logger.debug(
|
||||
|
|
Loading…
Reference in a new issue