Fetch thread summaries for multiple events in a single query (#11752)

This should reduce database usage when fetching bundled aggregations
as the number of individual queries (and round trips to the database) are
reduced.
This commit is contained in:
Patrick Cloke 2022-02-11 09:50:14 -05:00 committed by GitHub
parent bb98c593a5
commit b65acead42
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 151 additions and 74 deletions

1
changelog.d/11752.misc Normal file
View file

@ -0,0 +1 @@
Improve performance when fetching bundled aggregations for multiple events.

View file

@ -1812,7 +1812,7 @@ class PersistEventsStore:
# potentially error-prone) so it is always invalidated.
txn.call_after(
self.store.get_thread_participated.invalidate,
(parent_id, event.room_id, event.sender),
(parent_id, event.sender),
)
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):

View file

@ -20,6 +20,7 @@ from typing import (
Iterable,
List,
Optional,
Set,
Tuple,
Union,
cast,
@ -454,106 +455,175 @@ class RelationsWorkerStore(SQLBaseStore):
}
@cached()
async def get_thread_summary(
self, event_id: str, room_id: str
) -> Tuple[int, Optional[EventBase]]:
def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]:
raise NotImplementedError()
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
async def _get_thread_summaries(
self, event_ids: Collection[str]
) -> Dict[str, Optional[Tuple[int, EventBase]]]:
"""Get the number of threaded replies and the latest reply (if any) for the given event.
Args:
event_id: Summarize the thread related to this event ID.
room_id: The room the event belongs to.
event_ids: Summarize the thread related to this event ID.
Returns:
The number of items in the thread and the most recent response, if any.
A map of the thread summary each event. A missing event implies there
are no threaded replies.
Each summary includes the number of items in the thread and the most
recent response.
"""
def _get_thread_summary_txn(
def _get_thread_summaries_txn(
txn: LoggingTransaction,
) -> Tuple[int, Optional[str]]:
# Fetch the latest event ID in the thread.
) -> Tuple[Dict[str, int], Dict[str, str]]:
# Fetch the count of threaded events and the latest event ID.
# TODO Should this only allow m.room.message events.
if isinstance(self.database_engine, PostgresEngine):
# The `DISTINCT ON` clause will pick the *first* row it encounters,
# so ordering by topologica ordering + stream ordering desc will
# ensure we get the latest event in the thread.
sql = """
SELECT event_id
FROM event_relations
INNER JOIN events USING (event_id)
SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = child.room_id
WHERE
relates_to_id = ?
AND room_id = ?
%s
AND relation_type = ?
ORDER BY topological_ordering DESC, stream_ordering DESC
LIMIT 1
ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC
"""
else:
# SQLite uses a simplified query which returns all entries for a
# thread. The first result for each thread is chosen to and subsequent
# results for a thread are ignored.
sql = """
SELECT parent.event_id, child.event_id FROM events AS child
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = child.room_id
WHERE
%s
AND relation_type = ?
ORDER BY child.topological_ordering DESC, child.stream_ordering DESC
"""
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
row = txn.fetchone()
if row is None:
return 0, None
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", event_ids
)
args.append(RelationTypes.THREAD)
latest_event_id = row[0]
txn.execute(sql % (clause,), args)
latest_event_ids = {}
for parent_event_id, child_event_id in txn:
# Only consider the latest threaded reply (by topological ordering).
if parent_event_id not in latest_event_ids:
latest_event_ids[parent_event_id] = child_event_id
# If no threads were found, bail.
if not latest_event_ids:
return {}, latest_event_ids
# Fetch the number of threaded replies.
sql = """
SELECT COUNT(event_id)
FROM event_relations
INNER JOIN events USING (event_id)
SELECT parent.event_id, COUNT(child.event_id) FROM events AS child
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = child.room_id
WHERE
relates_to_id = ?
AND room_id = ?
%s
AND relation_type = ?
GROUP BY parent.event_id
"""
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
count = cast(Tuple[int], txn.fetchone())[0]
return count, latest_event_id
# Regenerate the arguments since only threads found above could
# possibly have any replies.
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", latest_event_ids.keys()
)
args.append(RelationTypes.THREAD)
count, latest_event_id = await self.db_pool.runInteraction(
"get_thread_summary", _get_thread_summary_txn
txn.execute(sql % (clause,), args)
counts = dict(cast(List[Tuple[str, int]], txn.fetchall()))
return counts, latest_event_ids
counts, latest_event_ids = await self.db_pool.runInteraction(
"get_thread_summaries", _get_thread_summaries_txn
)
latest_event = None
if latest_event_id:
latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined]
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
return count, latest_event
# Map to the event IDs to the thread summary.
#
# There might not be a summary due to there not being a thread or
# due to the latest event not being known, either case is treated the same.
summaries = {}
for parent_event_id, latest_event_id in latest_event_ids.items():
latest_event = latest_events.get(latest_event_id)
summary = None
if latest_event:
summary = (counts[parent_event_id], latest_event)
summaries[parent_event_id] = summary
return summaries
@cached()
async def get_thread_participated(
self, event_id: str, room_id: str, user_id: str
) -> bool:
"""Get whether the requesting user participated in a thread.
def get_thread_participated(self, event_id: str, user_id: str) -> bool:
raise NotImplementedError()
This is separate from get_thread_summary since that can be cached across
all users while this value is specific to the requeser.
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
async def _get_threads_participated(
self, event_ids: Collection[str], user_id: str
) -> Dict[str, bool]:
"""Get whether the requesting user participated in the given threads.
This is separate from get_thread_summaries since that can be cached across
all users while this value is specific to the requester.
Args:
event_id: The thread related to this event ID.
room_id: The room the event belongs to.
event_ids: The thread related to these event IDs.
user_id: The user requesting the summary.
Returns:
True if the requesting user participated in the thread, otherwise false.
A map of event ID to a boolean which represents if the requesting
user participated in that event's thread, otherwise false.
"""
def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
def _get_thread_summary_txn(txn: LoggingTransaction) -> Set[str]:
# Fetch whether the requester has participated or not.
sql = """
SELECT 1
FROM event_relations
INNER JOIN events USING (event_id)
SELECT DISTINCT relates_to_id
FROM events AS child
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = child.room_id
WHERE
relates_to_id = ?
AND room_id = ?
%s
AND relation_type = ?
AND sender = ?
AND child.sender = ?
"""
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
return bool(txn.fetchone())
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", event_ids
)
args.extend((RelationTypes.THREAD, user_id))
return await self.db_pool.runInteraction(
txn.execute(sql % (clause,), args)
return {row[0] for row in txn.fetchall()}
participated_threads = await self.db_pool.runInteraction(
"get_thread_summary", _get_thread_summary_txn
)
return {event_id: event_id in participated_threads for event_id in event_ids}
async def events_have_relations(
self,
parent_ids: List[str],
@ -700,21 +770,6 @@ class RelationsWorkerStore(SQLBaseStore):
if references.chunk:
aggregations.references = await references.to_dict(cast("DataStore", self))
# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
thread_count, latest_thread_event = await self.get_thread_summary(
event_id, room_id
)
participated = await self.get_thread_participated(
event_id, room_id, user_id
)
if latest_thread_event:
aggregations.thread = _ThreadAggregation(
latest_event=latest_thread_event,
count=thread_count,
current_user_participated=participated,
)
# Store the bundled aggregations in the event metadata for later use.
return aggregations
@ -763,6 +818,27 @@ class RelationsWorkerStore(SQLBaseStore):
for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit
# Fetch thread summaries.
if self._msc3440_enabled:
summaries = await self._get_thread_summaries(seen_event_ids)
# Only fetch participated for a limited selection based on what had
# summaries.
participated = await self._get_threads_participated(
summaries.keys(), user_id
)
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)
return results