Sliding Sync: Return room tags in account data extension (#17707)

The account data extension was also updated to avoid copies when we pull
the data out of the cache.

Fix https://github.com/element-hq/synapse/issues/17694
This commit is contained in:
Eric Eastwood 2024-09-16 13:47:35 -05:00 committed by GitHub
parent 285de43e48
commit 03937a1cae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 226 additions and 65 deletions

View file

@ -0,0 +1 @@
Return room tags in Sliding Sync account data extension.

View file

@ -14,7 +14,19 @@
import itertools import itertools
import logging import logging
from typing import TYPE_CHECKING, AbstractSet, Dict, Mapping, Optional, Sequence, Set from typing import (
TYPE_CHECKING,
AbstractSet,
ChainMap,
Dict,
List,
Mapping,
MutableMapping,
Optional,
Sequence,
Set,
cast,
)
from typing_extensions import assert_never from typing_extensions import assert_never
@ -381,29 +393,47 @@ class SlidingSyncExtensionHandler:
) )
) )
# TODO: This should take into account the `from_token` and `to_token`
have_push_rules_changed = await self.store.have_push_rules_changed_for_user( have_push_rules_changed = await self.store.have_push_rules_changed_for_user(
user_id, from_token.stream_token.push_rules_key user_id, from_token.stream_token.push_rules_key
) )
if have_push_rules_changed: if have_push_rules_changed:
global_account_data_map = dict(global_account_data_map)
# TODO: This should take into account the `from_token` and `to_token` # TODO: This should take into account the `from_token` and `to_token`
global_account_data_map[ global_account_data_map[
AccountDataTypes.PUSH_RULES AccountDataTypes.PUSH_RULES
] = await self.push_rules_handler.push_rules_for_user(sync_config.user) ] = await self.push_rules_handler.push_rules_for_user(sync_config.user)
else: else:
# TODO: This should take into account the `to_token` # TODO: This should take into account the `to_token`
all_global_account_data = await self.store.get_global_account_data_for_user( immutable_global_account_data_map = (
user_id await self.store.get_global_account_data_for_user(user_id)
) )
global_account_data_map = dict(all_global_account_data) # Use a `ChainMap` to avoid copying the immutable data from the cache
global_account_data_map = ChainMap(
{
# TODO: This should take into account the `to_token` # TODO: This should take into account the `to_token`
global_account_data_map[ AccountDataTypes.PUSH_RULES: await self.push_rules_handler.push_rules_for_user(
AccountDataTypes.PUSH_RULES sync_config.user
] = await self.push_rules_handler.push_rules_for_user(sync_config.user) )
},
# Cast is safe because `ChainMap` only mutates the top-most map,
# see https://github.com/python/typeshed/issues/8430
cast(
MutableMapping[str, JsonMapping], immutable_global_account_data_map
),
)
# Fetch room account data # Fetch room account data
account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]] = {} #
# List of -> Mapping from room_id to mapping of `type` to `content` of room
# account data events.
#
# This is is a list so we can avoid making copies of immutable data and instead
# just provide multiple maps that need to be combined. Normally, we could
# reach for `ChainMap` in this scenario, but this is a nested map and accessing
# the ChainMap by room_id won't combine the two maps for that room (we would
# need a new `NestedChainMap` type class).
account_data_by_room_maps: List[Mapping[str, Mapping[str, JsonMapping]]] = []
relevant_room_ids = self.find_relevant_room_ids_for_extension( relevant_room_ids = self.find_relevant_room_ids_for_extension(
requested_lists=account_data_request.lists, requested_lists=account_data_request.lists,
requested_room_ids=account_data_request.rooms, requested_room_ids=account_data_request.rooms,
@ -418,22 +448,66 @@ class SlidingSyncExtensionHandler:
user_id, from_token.stream_token.account_data_key user_id, from_token.stream_token.account_data_key
) )
) )
# Add room tags
#
# TODO: This should take into account the `from_token` and `to_token`
tags_by_room = await self.store.get_updated_tags(
user_id, from_token.stream_token.account_data_key
)
for room_id, tags in tags_by_room.items():
account_data_by_room_map.setdefault(room_id, {})[
AccountDataTypes.TAG
] = {"tags": tags}
account_data_by_room_maps.append(account_data_by_room_map)
else: else:
# TODO: This should take into account the `to_token` # TODO: This should take into account the `to_token`
account_data_by_room_map = ( immutable_account_data_by_room_map = (
await self.store.get_room_account_data_for_user(user_id) await self.store.get_room_account_data_for_user(user_id)
) )
account_data_by_room_maps.append(immutable_account_data_by_room_map)
# Filter down to the relevant rooms # Add room tags
account_data_by_room_map = { #
room_id: account_data_map # TODO: This should take into account the `to_token`
for room_id, account_data_map in account_data_by_room_map.items() tags_by_room = await self.store.get_tags_for_user(user_id)
if room_id in relevant_room_ids account_data_by_room_maps.append(
{
room_id: {AccountDataTypes.TAG: {"tags": tags}}
for room_id, tags in tags_by_room.items()
} }
)
# Filter down to the relevant rooms ... and combine the maps
relevant_account_data_by_room_map: MutableMapping[
str, Mapping[str, JsonMapping]
] = {}
for room_id in relevant_room_ids:
# We want to avoid adding empty maps for relevant rooms that have no room
# account data so do a quick check to see if it's in any of the maps.
is_room_in_maps = False
for room_map in account_data_by_room_maps:
if room_id in room_map:
is_room_in_maps = True
break
# If we found the room in any of the maps, combine the maps for that room
if is_room_in_maps:
relevant_account_data_by_room_map[room_id] = ChainMap(
{},
*(
# Cast is safe because `ChainMap` only mutates the top-most map,
# see https://github.com/python/typeshed/issues/8430
cast(MutableMapping[str, JsonMapping], room_map[room_id])
for room_map in account_data_by_room_maps
if room_map.get(room_id)
),
)
return SlidingSyncResult.Extensions.AccountDataExtension( return SlidingSyncResult.Extensions.AccountDataExtension(
global_account_data_map=global_account_data_map, global_account_data_map=global_account_data_map,
account_data_by_room_map=account_data_by_room_map, account_data_by_room_map=relevant_account_data_by_room_map,
) )
@trace @trace

View file

@ -177,7 +177,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
def get_room_account_data_for_user_txn( def get_room_account_data_for_user_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Dict[str, Dict[str, JsonDict]]: ) -> Dict[str, Dict[str, JsonMapping]]:
# The 'content != '{}' condition below prevents us from using # The 'content != '{}' condition below prevents us from using
# `simple_select_list_txn` here, as it doesn't support conditions # `simple_select_list_txn` here, as it doesn't support conditions
# other than 'equals'. # other than 'equals'.
@ -194,7 +194,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
by_room: Dict[str, Dict[str, JsonDict]] = {} by_room: Dict[str, Dict[str, JsonMapping]] = {}
for room_id, account_data_type, content in txn: for room_id, account_data_type, content in txn:
room_data = by_room.setdefault(room_id, {}) room_data = by_room.setdefault(room_id, {})
@ -394,7 +394,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
async def get_updated_global_account_data_for_user( async def get_updated_global_account_data_for_user(
self, user_id: str, stream_id: int self, user_id: str, stream_id: int
) -> Mapping[str, JsonMapping]: ) -> Dict[str, JsonMapping]:
"""Get all the global account_data that's changed for a user. """Get all the global account_data that's changed for a user.
Args: Args:
@ -407,7 +407,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
def get_updated_global_account_data_for_user( def get_updated_global_account_data_for_user(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Dict[str, JsonDict]: ) -> Dict[str, JsonMapping]:
sql = """ sql = """
SELECT account_data_type, content FROM account_data SELECT account_data_type, content FROM account_data
WHERE user_id = ? AND stream_id > ? WHERE user_id = ? AND stream_id > ?
@ -429,7 +429,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
async def get_updated_room_account_data_for_user( async def get_updated_room_account_data_for_user(
self, user_id: str, stream_id: int self, user_id: str, stream_id: int
) -> Dict[str, Dict[str, JsonDict]]: ) -> Dict[str, Dict[str, JsonMapping]]:
"""Get all the room account_data that's changed for a user. """Get all the room account_data that's changed for a user.
Args: Args:
@ -442,14 +442,14 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
def get_updated_room_account_data_for_user_txn( def get_updated_room_account_data_for_user_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Dict[str, Dict[str, JsonDict]]: ) -> Dict[str, Dict[str, JsonMapping]]:
sql = """ sql = """
SELECT room_id, account_data_type, content FROM room_account_data SELECT room_id, account_data_type, content FROM room_account_data
WHERE user_id = ? AND stream_id > ? WHERE user_id = ? AND stream_id > ?
""" """
txn.execute(sql, (user_id, stream_id)) txn.execute(sql, (user_id, stream_id))
account_data_by_room: Dict[str, Dict[str, JsonDict]] = {} account_data_by_room: Dict[str, Dict[str, JsonMapping]] = {}
for row in txn: for row in txn:
room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = db_to_json(row[2]) room_account_data[row[1]] = db_to_json(row[2])

View file

@ -314,8 +314,8 @@ class SlidingSyncResult:
"""The Account Data extension (MSC3959) """The Account Data extension (MSC3959)
Attributes: Attributes:
global_account_data_map: Mapping from `type` to `content` of global account global_account_data_map: Mapping from `type` to `content` of global
data events. account data events.
account_data_by_room_map: Mapping from room_id to mapping of `type` to account_data_by_room_map: Mapping from room_id to mapping of `type` to
`content` of room account data events. `content` of room account data events.
""" """

View file

@ -80,18 +80,23 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
} }
response_body, _ = self.do_sync(sync_body, tok=user1_tok) response_body, _ = self.do_sync(sync_body, tok=user1_tok)
self.assertIncludes( global_account_data_map = {
{ global_event["type"]: global_event["content"]
global_event["type"]
for global_event in response_body["extensions"]["account_data"].get( for global_event in response_body["extensions"]["account_data"].get(
"global" "global"
) )
}, }
self.assertIncludes(
global_account_data_map.keys(),
# Even though we don't have any global account data set, Synapse saves some # Even though we don't have any global account data set, Synapse saves some
# default push rules for us. # default push rules for us.
{AccountDataTypes.PUSH_RULES}, {AccountDataTypes.PUSH_RULES},
exact=True, exact=True,
) )
# Push rules are a giant chunk of JSON data so we will just assume the value is correct if they key is here.
# global_account_data_map[AccountDataTypes.PUSH_RULES]
# No room account data for this test
self.assertIncludes( self.assertIncludes(
response_body["extensions"]["account_data"].get("rooms").keys(), response_body["extensions"]["account_data"].get("rooms").keys(),
set(), set(),
@ -121,16 +126,19 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
# There has been no account data changes since the `from_token` so we shouldn't # There has been no account data changes since the `from_token` so we shouldn't
# see any account data here. # see any account data here.
self.assertIncludes( global_account_data_map = {
{ global_event["type"]: global_event["content"]
global_event["type"]
for global_event in response_body["extensions"]["account_data"].get( for global_event in response_body["extensions"]["account_data"].get(
"global" "global"
) )
}, }
self.assertIncludes(
global_account_data_map.keys(),
set(), set(),
exact=True, exact=True,
) )
# No room account data for this test
self.assertIncludes( self.assertIncludes(
response_body["extensions"]["account_data"].get("rooms").keys(), response_body["extensions"]["account_data"].get("rooms").keys(),
set(), set(),
@ -165,16 +173,24 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
response_body, _ = self.do_sync(sync_body, tok=user1_tok) response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# It should show us all of the global account data # It should show us all of the global account data
self.assertIncludes( global_account_data_map = {
{ global_event["type"]: global_event["content"]
global_event["type"]
for global_event in response_body["extensions"]["account_data"].get( for global_event in response_body["extensions"]["account_data"].get(
"global" "global"
) )
}, }
self.assertIncludes(
global_account_data_map.keys(),
{AccountDataTypes.PUSH_RULES, "org.matrix.foobarbaz"}, {AccountDataTypes.PUSH_RULES, "org.matrix.foobarbaz"},
exact=True, exact=True,
) )
# Push rules are a giant chunk of JSON data so we will just assume the value is correct if they key is here.
# global_account_data_map[AccountDataTypes.PUSH_RULES]
self.assertEqual(
global_account_data_map["org.matrix.foobarbaz"], {"foo": "bar"}
)
# No room account data for this test
self.assertIncludes( self.assertIncludes(
response_body["extensions"]["account_data"].get("rooms").keys(), response_body["extensions"]["account_data"].get("rooms").keys(),
set(), set(),
@ -220,17 +236,23 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
# Make an incremental Sliding Sync request with the account_data extension enabled # Make an incremental Sliding Sync request with the account_data extension enabled
response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
self.assertIncludes( global_account_data_map = {
{ global_event["type"]: global_event["content"]
global_event["type"]
for global_event in response_body["extensions"]["account_data"].get( for global_event in response_body["extensions"]["account_data"].get(
"global" "global"
) )
}, }
self.assertIncludes(
global_account_data_map.keys(),
# We should only see the new global account data that happened after the `from_token` # We should only see the new global account data that happened after the `from_token`
{"org.matrix.doodardaz"}, {"org.matrix.doodardaz"},
exact=True, exact=True,
) )
self.assertEqual(
global_account_data_map["org.matrix.doodardaz"], {"doo": "dar"}
)
# No room account data for this test
self.assertIncludes( self.assertIncludes(
response_body["extensions"]["account_data"].get("rooms").keys(), response_body["extensions"]["account_data"].get("rooms").keys(),
set(), set(),
@ -255,6 +277,15 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
content={"roo": "rar"}, content={"roo": "rar"},
) )
) )
# Add a room tag to mark the room as a favourite
self.get_success(
self.account_data_handler.add_tag_to_room(
user_id=user1_id,
room_id=room_id1,
tag="m.favourite",
content={},
)
)
# Create another room with some room account data # Create another room with some room account data
room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok) room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
@ -266,6 +297,15 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
content={"roo": "rar"}, content={"roo": "rar"},
) )
) )
# Add a room tag to mark the room as a favourite
self.get_success(
self.account_data_handler.add_tag_to_room(
user_id=user1_id,
room_id=room_id2,
tag="m.favourite",
content={},
)
)
# Make an initial Sliding Sync request with the account_data extension enabled # Make an initial Sliding Sync request with the account_data extension enabled
sync_body = { sync_body = {
@ -294,16 +334,21 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
{room_id1}, {room_id1},
exact=True, exact=True,
) )
self.assertIncludes( account_data_map = {
{ event["type"]: event["content"]
event["type"]
for event in response_body["extensions"]["account_data"] for event in response_body["extensions"]["account_data"]
.get("rooms") .get("rooms")
.get(room_id1) .get(room_id1)
}, }
{"org.matrix.roorarraz"}, self.assertIncludes(
account_data_map.keys(),
{"org.matrix.roorarraz", AccountDataTypes.TAG},
exact=True, exact=True,
) )
self.assertEqual(account_data_map["org.matrix.roorarraz"], {"roo": "rar"})
self.assertEqual(
account_data_map[AccountDataTypes.TAG], {"tags": {"m.favourite": {}}}
)
def test_room_account_data_incremental_sync(self) -> None: def test_room_account_data_incremental_sync(self) -> None:
""" """
@ -323,6 +368,15 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
content={"roo": "rar"}, content={"roo": "rar"},
) )
) )
# Add a room tag to mark the room as a favourite
self.get_success(
self.account_data_handler.add_tag_to_room(
user_id=user1_id,
room_id=room_id1,
tag="m.favourite",
content={},
)
)
# Create another room with some room account data # Create another room with some room account data
room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok) room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
@ -334,6 +388,15 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
content={"roo": "rar"}, content={"roo": "rar"},
) )
) )
# Add a room tag to mark the room as a favourite
self.get_success(
self.account_data_handler.add_tag_to_room(
user_id=user1_id,
room_id=room_id2,
tag="m.favourite",
content={},
)
)
sync_body = { sync_body = {
"lists": {}, "lists": {},
@ -369,6 +432,23 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
content={"roo": "rar"}, content={"roo": "rar"},
) )
) )
# Add another room tag
self.get_success(
self.account_data_handler.add_tag_to_room(
user_id=user1_id,
room_id=room_id1,
tag="m.server_notice",
content={},
)
)
self.get_success(
self.account_data_handler.add_tag_to_room(
user_id=user1_id,
room_id=room_id2,
tag="m.server_notice",
content={},
)
)
# Make an incremental Sliding Sync request with the account_data extension enabled # Make an incremental Sliding Sync request with the account_data extension enabled
response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
@ -383,16 +463,22 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
exact=True, exact=True,
) )
# We should only see the new room account data that happened after the `from_token` # We should only see the new room account data that happened after the `from_token`
self.assertIncludes( account_data_map = {
{ event["type"]: event["content"]
event["type"]
for event in response_body["extensions"]["account_data"] for event in response_body["extensions"]["account_data"]
.get("rooms") .get("rooms")
.get(room_id1) .get(room_id1)
}, }
{"org.matrix.roorarraz2"}, self.assertIncludes(
account_data_map.keys(),
{"org.matrix.roorarraz2", AccountDataTypes.TAG},
exact=True, exact=True,
) )
self.assertEqual(account_data_map["org.matrix.roorarraz2"], {"roo": "rar"})
self.assertEqual(
account_data_map[AccountDataTypes.TAG],
{"tags": {"m.favourite": {}, "m.server_notice": {}}},
)
def test_wait_for_new_data(self) -> None: def test_wait_for_new_data(self) -> None:
""" """