Reduce state pulled from DB due to sending typing and receipts over federation (#12964)

Reducing the amount of state we pull from the DB is useful as fetching state is expensive in terms of DB, CPU and memory.
This commit is contained in:
Erik Johnston 2022-06-06 18:46:11 +03:00 committed by GitHub
parent 148fe58a24
commit 44de53bb79
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 68 additions and 16 deletions

1
changelog.d/12964.misc Normal file
View file

@ -0,0 +1 @@
Reduce the amount of state we pull from the DB.

View file

@ -245,6 +245,8 @@ class FederationSender(AbstractFederationSender):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self._storage_controllers = hs.get_storage_controllers()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
@ -602,7 +604,9 @@ class FederationSender(AbstractFederationSender):
room_id = receipt.room_id room_id = receipt.room_id
# Work out which remote servers should be poked and poke them. # Work out which remote servers should be poked and poke them.
domains_set = await self.state.get_current_hosts_in_room(room_id) domains_set = await self._storage_controllers.state.get_current_hosts_in_room(
room_id
)
domains = [ domains = [
d d
for d in domains_set for d in domains_set

View file

@ -59,6 +59,7 @@ class FollowerTypingHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.server_name = hs.config.server.server_name self.server_name = hs.config.server.server_name
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
@ -131,7 +132,6 @@ class FollowerTypingHandler:
return return
try: try:
users = await self.store.get_users_in_room(member.room_id)
self._member_last_federation_poke[member] = self.clock.time_msec() self._member_last_federation_poke[member] = self.clock.time_msec()
now = self.clock.time_msec() now = self.clock.time_msec()
@ -139,7 +139,10 @@ class FollowerTypingHandler:
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
) )
for domain in {get_domain_from_id(u) for u in users}: hosts = await self._storage_controllers.state.get_current_hosts_in_room(
member.room_id
)
for domain in hosts:
if domain != self.server_name: if domain != self.server_name:
logger.debug("sending typing update to %s", domain) logger.debug("sending typing update to %s", domain)
self.federation.build_and_send_edu( self.federation.build_and_send_edu(

View file

@ -172,10 +172,6 @@ class StateHandler:
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return await self.store.get_joined_users_from_state(room_id, entry) return await self.store.get_joined_users_from_state(room_id, entry)
async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)
async def get_hosts_in_room_at_events( async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str] self, room_id: str, event_ids: Collection[str]
) -> FrozenSet[str]: ) -> FrozenSet[str]:

View file

@ -71,6 +71,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
if members_changed: if members_changed:
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,))
self._attempt_to_invalidate_cache( self._attempt_to_invalidate_cache(
"get_users_in_room_with_profiles", (room_id,) "get_users_in_room_with_profiles", (room_id,)
) )

View file

@ -23,6 +23,7 @@ from typing import (
List, List,
Mapping, Mapping,
Optional, Optional,
Set,
Tuple, Tuple,
) )
@ -482,3 +483,10 @@ class StateStorageController:
room_id, StateFilter.from_types((key,)) room_id, StateFilter.from_types((key,))
) )
return state_map.get(key) return state_map.get(key)
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
"""Get current hosts in room based on current state."""
await self._partial_state_room_tracker.await_full_state(room_id)
return await self.stores.main.get_current_hosts_in_room(room_id)

View file

@ -893,6 +893,43 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True return True
@cached(iterable=True, max_entries=10000)
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
"""Get current hosts in room based on current state."""
# First we check if we already have `get_users_in_room` in the cache, as
# we can just calculate result from that
users = self.get_users_in_room.cache.get_immediate(
(room_id,), None, update_metrics=False
)
if users is not None:
return {get_domain_from_id(u) for u in users}
if isinstance(self.database_engine, Sqlite3Engine):
# If we're using SQLite then let's just always use
# `get_users_in_room` rather than funky SQL.
users = await self.get_users_in_room(room_id)
return {get_domain_from_id(u) for u in users}
# For PostgreSQL we can use a regex to pull out the domains from the
# joined users in `current_state_events` via regex.
def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
sql = """
SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
FROM current_state_events
WHERE
type = 'm.room.member'
AND membership = 'join'
AND room_id = ?
"""
txn.execute(sql, (room_id,))
return {d for d, in txn}
return await self.db_pool.runInteraction(
"get_current_hosts_in_room", get_current_hosts_in_room_txn
)
async def get_joined_hosts( async def get_joined_hosts(
self, room_id: str, state_entry: "_StateCacheEntry" self, room_id: str, state_entry: "_StateCacheEntry"
) -> FrozenSet[str]: ) -> FrozenSet[str]:

View file

@ -30,16 +30,16 @@ from tests.unittest import HomeserverTestCase, override_config
class FederationSenderReceiptsTestCases(HomeserverTestCase): class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
mock_state_handler = Mock(spec=["get_current_hosts_in_room"]) hs = self.setup_test_homeserver(
# Ensure a new Awaitable is created for each call.
mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
["test", "host2"]
)
return self.setup_test_homeserver(
state_handler=mock_state_handler,
federation_transport_client=Mock(spec=["send_transaction"]), federation_transport_client=Mock(spec=["send_transaction"]),
) )
hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(
return_value=make_awaitable({"test", "host2"})
)
return hs
@override_config({"send_federation": True}) @override_config({"send_federation": True})
def test_send_receipts(self): def test_send_receipts(self):
mock_send_transaction = ( mock_send_transaction = (

View file

@ -129,10 +129,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
hs.get_event_auth_handler().check_host_in_room = check_host_in_room hs.get_event_auth_handler().check_host_in_room = check_host_in_room
def get_joined_hosts_for_room(room_id: str): async def get_current_hosts_in_room(room_id: str):
return {member.domain for member in self.room_members} return {member.domain for member in self.room_members}
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room hs.get_storage_controllers().state.get_current_hosts_in_room = (
get_current_hosts_in_room
)
async def get_users_in_room(room_id: str): async def get_users_in_room(room_id: str):
return {str(u) for u in self.room_members} return {str(u) for u in self.room_members}