Include whether the requesting user has participated in a thread. (#11577)

Per updates to MSC3440.

This is implement as a separate method since it needs to be cached
on a per-user basis, instead of a per-thread basis.
This commit is contained in:
Patrick Cloke 2022-01-18 11:38:57 -05:00 committed by GitHub
parent 251b5567ec
commit 68acb0a29d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 86 additions and 19 deletions

View file

@ -0,0 +1 @@
Include whether the requesting user has participated in a thread when generating a summary for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440).

View file

@ -537,7 +537,7 @@ class PaginationHandler:
state_dict = await self.store.get_events(list(state_ids.values()))
state = state_dict.values()
aggregations = await self.store.get_bundled_aggregations(events)
aggregations = await self.store.get_bundled_aggregations(events, user_id)
time_now = self.clock.time_msec()

View file

@ -1182,12 +1182,18 @@ class RoomContextHandler:
results["event"] = filtered[0]
# Fetch the aggregations.
aggregations = await self.store.get_bundled_aggregations([results["event"]])
aggregations.update(
await self.store.get_bundled_aggregations(results["events_before"])
aggregations = await self.store.get_bundled_aggregations(
[results["event"]], user.to_string()
)
aggregations.update(
await self.store.get_bundled_aggregations(results["events_after"])
await self.store.get_bundled_aggregations(
results["events_before"], user.to_string()
)
)
aggregations.update(
await self.store.get_bundled_aggregations(
results["events_after"], user.to_string()
)
)
results["aggregations"] = aggregations

View file

@ -637,7 +637,9 @@ class SyncHandler:
# as clients will have all the necessary information.
bundled_aggregations = None
if limited or newly_joined_room:
bundled_aggregations = await self.store.get_bundled_aggregations(recents)
bundled_aggregations = await self.store.get_bundled_aggregations(
recents, sync_config.user.to_string()
)
return TimelineBatch(
events=recents,

View file

@ -118,7 +118,9 @@ class RelationPaginationServlet(RestServlet):
)
# The relations returned for the requested event do include their
# bundled aggregations.
aggregations = await self.store.get_bundled_aggregations(events)
aggregations = await self.store.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)

View file

@ -663,7 +663,9 @@ class RoomEventServlet(RestServlet):
if event:
# Ensure there are bundled aggregations available.
aggregations = await self._store.get_bundled_aggregations([event])
aggregations = await self._store.get_bundled_aggregations(
[event], requester.user.to_string()
)
time_now = self.clock.time_msec()
event_dict = self._event_serializer.serialize_event(

View file

@ -1793,6 +1793,13 @@ class PersistEventsStore:
txn.call_after(
self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
)
# It should be safe to only invalidate the cache if the user has not
# previously participated in the thread, but that's difficult (and
# potentially error-prone) so it is always invalidated.
txn.call_after(
self.store.get_thread_participated.invalidate,
(parent_id, event.room_id, event.sender),
)
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
"""Handles keeping track of insertion events and edges/connections.

View file

@ -384,8 +384,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def get_thread_summary(
self, event_id: str, room_id: str
) -> Tuple[int, Optional[EventBase]]:
"""Get the number of threaded replies, the senders of those replies, and
the latest reply (if any) for the given event.
"""Get the number of threaded replies and the latest reply (if any) for the given event.
Args:
event_id: Summarize the thread related to this event ID.
@ -398,7 +397,7 @@ class RelationsWorkerStore(SQLBaseStore):
def _get_thread_summary_txn(
txn: LoggingTransaction,
) -> Tuple[int, Optional[str]]:
# Fetch the count of threaded events and the latest event ID.
# Fetch the latest event ID in the thread.
# TODO Should this only allow m.room.message events.
sql = """
SELECT event_id
@ -419,6 +418,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_event_id = row[0]
# Fetch the number of threaded replies.
sql = """
SELECT COUNT(event_id)
FROM event_relations
@ -443,6 +443,44 @@ class RelationsWorkerStore(SQLBaseStore):
return count, latest_event
@cached()
async def get_thread_participated(
self, event_id: str, room_id: str, user_id: str
) -> bool:
"""Get whether the requesting user participated in a thread.
This is separate from get_thread_summary since that can be cached across
all users while this value is specific to the requeser.
Args:
event_id: The thread related to this event ID.
room_id: The room the event belongs to.
user_id: The user requesting the summary.
Returns:
True if the requesting user participated in the thread, otherwise false.
"""
def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
# Fetch whether the requester has participated or not.
sql = """
SELECT 1
FROM event_relations
INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
AND room_id = ?
AND relation_type = ?
AND sender = ?
"""
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
return bool(txn.fetchone())
return await self.db_pool.runInteraction(
"get_thread_summary", _get_thread_summary_txn
)
async def events_have_relations(
self,
parent_ids: List[str],
@ -546,7 +584,7 @@ class RelationsWorkerStore(SQLBaseStore):
)
async def _get_bundled_aggregation_for_event(
self, event: EventBase
self, event: EventBase, user_id: str
) -> Optional[Dict[str, Any]]:
"""Generate bundled aggregations for an event.
@ -554,6 +592,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event: The event to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
The bundled aggregations for an event, if bundled aggregations are
@ -598,27 +637,32 @@ class RelationsWorkerStore(SQLBaseStore):
# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
(
thread_count,
latest_thread_event,
) = await self.get_thread_summary(event_id, room_id)
thread_count, latest_thread_event = await self.get_thread_summary(
event_id, room_id
)
participated = await self.get_thread_participated(
event_id, room_id, user_id
)
if latest_thread_event:
aggregations[RelationTypes.THREAD] = {
# Don't bundle aggregations as this could recurse forever.
"latest_event": latest_thread_event,
"count": thread_count,
"current_user_participated": participated,
}
# Store the bundled aggregations in the event metadata for later use.
return aggregations
async def get_bundled_aggregations(
self, events: Iterable[EventBase]
self,
events: Iterable[EventBase],
user_id: str,
) -> Dict[str, Dict[str, Any]]:
"""Generate bundled aggregations for events.
Args:
events: The iterable of events to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
A map of event ID to the bundled aggregation for the event. Not all
@ -631,7 +675,7 @@ class RelationsWorkerStore(SQLBaseStore):
# TODO Parallelize.
results = {}
for event in events:
event_result = await self._get_bundled_aggregation_for_event(event)
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result is not None:
results[event.event_id] = event_result

View file

@ -515,6 +515,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
2,
actual[RelationTypes.THREAD].get("count"),
)
self.assertTrue(
actual[RelationTypes.THREAD].get("current_user_participated")
)
# The latest thread event has some fields that don't matter.
self.assert_dict(
{