Sliding Sync: Make sure we get up-to-date information from get_sliding_sync_rooms_for_user(...) (#17692)

We need to bust the `get_sliding_sync_rooms_for_user`
cache when the room encryption is updated and any
other field that is used in the query.

Follow-up to https://github.com/element-hq/synapse/pull/17630

- Bust cache for membership change (cross-reference
`get_rooms_for_user`)
- Bust cache for room `encryption` (cross-reference
`get_room_encryption`)
- Bust cache for `forgotten` (cross-reference
`did_forget`/`get_forgotten_rooms_for_user`)
This commit is contained in:
Eric Eastwood 2024-09-11 12:13:54 -05:00 committed by GitHub
parent 6b131a99fe
commit e4a1f271b9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 160 additions and 56 deletions

1
changelog.d/17692.bugfix Normal file
View file

@ -0,0 +1 @@
Make sure we get up-to-date state information when using the new Sliding Sync tables to derive room membership.

View file

@ -136,6 +136,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,)) self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,))
self._attempt_to_invalidate_cache("get_room_type", (room_id,)) self._attempt_to_invalidate_cache("get_room_type", (room_id,))
self._attempt_to_invalidate_cache("get_room_encryption", (room_id,)) self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
self._attempt_to_invalidate_cache("get_sliding_sync_rooms_for_user", None)
def _invalidate_state_caches_all(self, room_id: str) -> None: def _invalidate_state_caches_all(self, room_id: str) -> None:
"""Invalidates caches that are based on the current state, but does """Invalidates caches that are based on the current state, but does

View file

@ -41,6 +41,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection, LoggingDatabaseConnection,
LoggingTransaction, LoggingTransaction,
) )
from synapse.storage.databases.main.events import SLIDING_SYNC_RELEVANT_STATE_SET
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util.caches.descriptors import CachedFunction from synapse.util.caches.descriptors import CachedFunction
@ -271,12 +272,20 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache( self._attempt_to_invalidate_cache(
"get_rooms_for_user", (data.state_key,) "get_rooms_for_user", (data.state_key,)
) )
self._attempt_to_invalidate_cache(
"get_sliding_sync_rooms_for_user", None
)
elif data.type == EventTypes.RoomEncryption: elif data.type == EventTypes.RoomEncryption:
self._attempt_to_invalidate_cache( self._attempt_to_invalidate_cache(
"get_room_encryption", (data.room_id,) "get_room_encryption", (data.room_id,)
) )
elif data.type == EventTypes.Create: elif data.type == EventTypes.Create:
self._attempt_to_invalidate_cache("get_room_type", (data.room_id,)) self._attempt_to_invalidate_cache("get_room_type", (data.room_id,))
if (data.type, data.state_key) in SLIDING_SYNC_RELEVANT_STATE_SET:
self._attempt_to_invalidate_cache(
"get_sliding_sync_rooms_for_user", None
)
elif row.type == EventsStreamAllStateRow.TypeId: elif row.type == EventsStreamAllStateRow.TypeId:
assert isinstance(data, EventsStreamAllStateRow) assert isinstance(data, EventsStreamAllStateRow)
# Similar to the above, but the entire caches are invalidated. This is # Similar to the above, but the entire caches are invalidated. This is
@ -285,6 +294,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache("get_rooms_for_user", None) self._attempt_to_invalidate_cache("get_rooms_for_user", None)
self._attempt_to_invalidate_cache("get_room_type", (data.room_id,)) self._attempt_to_invalidate_cache("get_room_type", (data.room_id,))
self._attempt_to_invalidate_cache("get_room_encryption", (data.room_id,)) self._attempt_to_invalidate_cache("get_room_encryption", (data.room_id,))
self._attempt_to_invalidate_cache("get_sliding_sync_rooms_for_user", None)
else: else:
raise Exception("Unknown events stream row type %s" % (row.type,)) raise Exception("Unknown events stream row type %s" % (row.type,))
@ -365,6 +375,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
elif etype == EventTypes.RoomEncryption: elif etype == EventTypes.RoomEncryption:
self._attempt_to_invalidate_cache("get_room_encryption", (room_id,)) self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
if (etype, state_key) in SLIDING_SYNC_RELEVANT_STATE_SET:
self._attempt_to_invalidate_cache("get_sliding_sync_rooms_for_user", None)
if relates_to: if relates_to:
self._attempt_to_invalidate_cache( self._attempt_to_invalidate_cache(
"get_relations_for_event", "get_relations_for_event",
@ -477,6 +490,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache( self._attempt_to_invalidate_cache(
"get_current_hosts_in_room_ordered", (room_id,) "get_current_hosts_in_room_ordered", (room_id,)
) )
self._attempt_to_invalidate_cache("get_sliding_sync_rooms_for_user", None)
self._attempt_to_invalidate_cache("did_forget", None) self._attempt_to_invalidate_cache("did_forget", None)
self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None) self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None)
self._attempt_to_invalidate_cache("_get_membership_from_event_id", None) self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)

View file

@ -1365,6 +1365,9 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.get_forgotten_rooms_for_user, (user_id,) txn, self.get_forgotten_rooms_for_user, (user_id,)
) )
self._invalidate_cache_and_stream(
txn, self.get_sliding_sync_rooms_for_user, (user_id,)
)
await self.db_pool.runInteraction("forget_membership", f) await self.db_pool.runInteraction("forget_membership", f)
@ -1410,6 +1413,10 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
def get_sliding_sync_rooms_for_user_txn( def get_sliding_sync_rooms_for_user_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Dict[str, RoomsForUserSlidingSync]: ) -> Dict[str, RoomsForUserSlidingSync]:
# XXX: If you use any new columns that can change (like from
# `sliding_sync_joined_rooms` or `forgotten`), make sure to bust the
# `get_sliding_sync_rooms_for_user` cache in the appropriate places (and add
# tests).
sql = """ sql = """
SELECT m.room_id, m.sender, m.membership, m.membership_event_id, SELECT m.room_id, m.sender, m.membership, m.membership_event_id,
r.room_version, r.room_version,
@ -1432,7 +1439,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
room_version_id=row[4], room_version_id=row[4],
event_pos=PersistedEventPosition(row[5], row[6]), event_pos=PersistedEventPosition(row[5], row[6]),
room_type=row[7], room_type=row[7],
is_encrypted=row[8], is_encrypted=bool(row[8]),
) )
for row in txn for row in txn
} }

View file

@ -722,10 +722,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
self.helper.join(space_room_id, user1_id, tok=user1_tok) self.helper.join(space_room_id, user1_id, tok=user1_tok)
# Make an initial Sliding Sync request # Make an initial Sliding Sync request
channel = self.make_request( sync_body = {
"POST",
self.sync_endpoint,
{
"lists": { "lists": {
"all-list": { "all-list": {
"ranges": [[0, 99]], "ranges": [[0, 99]],
@ -743,22 +740,19 @@ class SlidingSyncTestCase(SlidingSyncBase):
}, },
}, },
} }
}, }
access_token=user1_tok, response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
)
self.assertEqual(channel.code, 200, channel.json_body)
from_token = channel.json_body["pos"]
# Make sure the response has the lists we requested # Make sure the response has the lists we requested
self.assertListEqual( self.assertListEqual(
list(channel.json_body["lists"].keys()), list(response_body["lists"].keys()),
["all-list", "foo-list"], ["all-list", "foo-list"],
channel.json_body["lists"].keys(), response_body["lists"].keys(),
) )
# Make sure the lists have the correct rooms # Make sure the lists have the correct rooms
self.assertListEqual( self.assertListEqual(
list(channel.json_body["lists"]["all-list"]["ops"]), list(response_body["lists"]["all-list"]["ops"]),
[ [
{ {
"op": "SYNC", "op": "SYNC",
@ -768,7 +762,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
], ],
) )
self.assertListEqual( self.assertListEqual(
list(channel.json_body["lists"]["foo-list"]["ops"]), list(response_body["lists"]["foo-list"]["ops"]),
[ [
{ {
"op": "SYNC", "op": "SYNC",
@ -783,10 +777,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
self.helper.leave(space_room_id, user2_id, tok=user2_tok) self.helper.leave(space_room_id, user2_id, tok=user2_tok)
# Make an incremental Sliding Sync request # Make an incremental Sliding Sync request
channel = self.make_request( sync_body = {
"POST",
self.sync_endpoint + f"?pos={from_token}",
{
"lists": { "lists": {
"all-list": { "all-list": {
"ranges": [[0, 99]], "ranges": [[0, 99]],
@ -804,14 +795,12 @@ class SlidingSyncTestCase(SlidingSyncBase):
}, },
}, },
} }
}, }
access_token=user1_tok, response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
)
self.assertEqual(channel.code, 200, channel.json_body)
# Make sure the lists have the correct rooms even though we `newly_left` # Make sure the lists have the correct rooms even though we `newly_left`
self.assertListEqual( self.assertListEqual(
list(channel.json_body["lists"]["all-list"]["ops"]), list(response_body["lists"]["all-list"]["ops"]),
[ [
{ {
"op": "SYNC", "op": "SYNC",
@ -821,7 +810,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
], ],
) )
self.assertListEqual( self.assertListEqual(
list(channel.json_body["lists"]["foo-list"]["ops"]), list(response_body["lists"]["foo-list"]["ops"]),
[ [
{ {
"op": "SYNC", "op": "SYNC",
@ -831,6 +820,98 @@ class SlidingSyncTestCase(SlidingSyncBase):
], ],
) )
def test_filter_is_encrypted_up_to_date(self) -> None:
"""
Make sure we get up-to-date `is_encrypted` status for a joined room
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
sync_body = {
"lists": {
"foo-list": {
"ranges": [[0, 99]],
"required_state": [],
"timeline_limit": 0,
"filters": {
"is_encrypted": True,
},
},
}
}
response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
self.assertIncludes(
set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
set(),
exact=True,
)
# Update the encryption status
self.helper.send_state(
room_id,
EventTypes.RoomEncryption,
{EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
tok=user1_tok,
)
# We should see the room now because it's encrypted
response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
self.assertIncludes(
set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
{room_id},
exact=True,
)
def test_forgotten_up_to_date(self) -> None:
"""
Make sure we get up-to-date `forgotten` status for rooms
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
# User1 is banned from the room (was never in the room)
self.helper.ban(room_id, src=user2_id, targ=user1_id, tok=user2_tok)
sync_body = {
"lists": {
"foo-list": {
"ranges": [[0, 99]],
"required_state": [],
"timeline_limit": 0,
"filters": {},
},
}
}
response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
self.assertIncludes(
set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
{room_id},
exact=True,
)
# User1 forgets the room
channel = self.make_request(
"POST",
f"/_matrix/client/r0/rooms/{room_id}/forget",
content={},
access_token=user1_tok,
)
self.assertEqual(channel.code, 200, channel.result)
# We should no longer see the forgotten room
response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
self.assertIncludes(
set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
set(),
exact=True,
)
def test_sort_list(self) -> None: def test_sort_list(self) -> None:
""" """
Test that the `lists` are sorted by `stream_ordering` Test that the `lists` are sorted by `stream_ordering`