Make the chain cover index background update go faster (#9124)

We do this by allowing a single iteration to process multiple rooms at a
time. There are often a lot of really tiny rooms, and handling only one
room per iteration can massively slow things down.
This commit is contained in:
Erik Johnston 2021-01-15 17:18:37 +00:00 committed by GitHub
parent 2de7e263ed
commit 350d9923cd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 366 additions and 181 deletions

1
changelog.d/9124.misc Normal file
View file

@ -0,0 +1 @@
Improve efficiency of large state resolutions.

View file

@ -16,6 +16,8 @@
import logging import logging
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import attr
from synapse.api.constants import EventContentFields from synapse.api.constants import EventContentFields
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict from synapse.events import make_event_from_dict
@ -28,6 +30,25 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@attr.s(frozen=True, slots=True)
class _CalculateChainCover:
    """Immutable result returned by `_calculate_chain_cover_txn`.

    Bundles the position the transaction reached, how many rows it handled,
    and which rooms were fully processed along the way.
    """

    # Position of the last row processed, as a `(room_id, depth, stream)`
    # triple split across three fields.
    room_id = attr.ib(type=str)
    depth = attr.ib(type=int)
    stream = attr.ib(type=int)

    # How many rows the transaction processed.
    processed_count = attr.ib(type=int)

    # For every room whose events have all been processed: room_id mapped to
    # the last `(depth, stream)` handled in that room. These are the rooms
    # for which `has_auth_chain_index` can now be flipped.
    finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
class EventsBackgroundUpdatesStore(SQLBaseStore): class EventsBackgroundUpdatesStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
@ -719,53 +740,94 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
current_room_id = progress.get("current_room_id", "") current_room_id = progress.get("current_room_id", "")
# Have we finished processing the current room.
finished = progress.get("finished", True)
# Where we've processed up to in the room, defaults to the start of the # Where we've processed up to in the room, defaults to the start of the
# room. # room.
last_depth = progress.get("last_depth", -1) last_depth = progress.get("last_depth", -1)
last_stream = progress.get("last_stream", -1) last_stream = progress.get("last_stream", -1)
# Have we set the `has_auth_chain_index` for the room yet. result = await self.db_pool.runInteraction(
has_set_room_has_chain_index = progress.get( "_chain_cover_index",
"has_set_room_has_chain_index", False self._calculate_chain_cover_txn,
current_room_id,
last_depth,
last_stream,
batch_size,
single_room=False,
) )
finished = result.processed_count == 0
total_rows_processed = result.processed_count
current_room_id = result.room_id
last_depth = result.depth
last_stream = result.stream
for room_id, (depth, stream) in result.finished_room_map.items():
# If we've done all the events in the room we flip the
# `has_auth_chain_index` in the DB. Note that its possible for
# further events to be persisted between the above and setting the
# flag without having the chain cover calculated for them. This is
# fine as a) the code gracefully handles these cases and b) we'll
# calculate them below.
await self.db_pool.simple_update(
table="rooms",
keyvalues={"room_id": room_id},
updatevalues={"has_auth_chain_index": True},
desc="_chain_cover_index",
)
# Handle any events that might have raced with us flipping the
# bit above.
result = await self.db_pool.runInteraction(
"_chain_cover_index",
self._calculate_chain_cover_txn,
room_id,
depth,
stream,
batch_size=None,
single_room=True,
)
total_rows_processed += result.processed_count
if finished: if finished:
# If we've finished with the previous room (or its our first
# iteration) we move on to the next room.
def _get_next_room(txn: Cursor) -> Optional[str]:
sql = """
SELECT room_id FROM rooms
WHERE room_id > ?
AND (
NOT has_auth_chain_index
OR has_auth_chain_index IS NULL
)
ORDER BY room_id
LIMIT 1
"""
txn.execute(sql, (current_room_id,))
row = txn.fetchone()
if row:
return row[0]
return None
current_room_id = await self.db_pool.runInteraction(
"_chain_cover_index", _get_next_room
)
if not current_room_id:
await self.db_pool.updates._end_background_update("chain_cover") await self.db_pool.updates._end_background_update("chain_cover")
return 0 return total_rows_processed
logger.debug("Adding chain cover to %s", current_room_id) await self.db_pool.updates._background_update_progress(
"chain_cover",
{
"current_room_id": current_room_id,
"last_depth": last_depth,
"last_stream": last_stream,
},
)
return total_rows_processed
def _calculate_chain_cover_txn(
self,
txn: Cursor,
last_room_id: str,
last_depth: int,
last_stream: int,
batch_size: Optional[int],
single_room: bool,
) -> _CalculateChainCover:
"""Calculate the chain cover for `batch_size` events, ordered by
`(room_id, depth, stream)`.
Args:
txn,
last_room_id, last_depth, last_stream: The `(room_id, depth, stream)`
tuple to fetch results after.
batch_size: The maximum number of events to process. If None then
no limit.
single_room: Whether to calculate the index for just the given
room.
"""
def _calculate_auth_chain(
txn: Cursor, last_depth: int, last_stream: int
) -> Tuple[int, int, int]:
# Get the next set of events in the room (that we haven't already # Get the next set of events in the room (that we haven't already
# computed chain cover for). We do this in topological order. # computed chain cover for). We do this in topological order.
@ -774,43 +836,66 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
tuple_clause, tuple_args = make_tuple_comparison_clause( tuple_clause, tuple_args = make_tuple_comparison_clause(
self.database_engine, self.database_engine,
[ [
("events.room_id", last_room_id),
("topological_ordering", last_depth), ("topological_ordering", last_depth),
("stream_ordering", last_stream), ("stream_ordering", last_stream),
], ],
) )
extra_clause = ""
if single_room:
extra_clause = "AND events.room_id = ?"
tuple_args.append(last_room_id)
sql = """ sql = """
SELECT SELECT
event_id, state_events.type, state_events.state_key, event_id, state_events.type, state_events.state_key,
topological_ordering, stream_ordering topological_ordering, stream_ordering,
events.room_id
FROM events FROM events
INNER JOIN state_events USING (event_id) INNER JOIN state_events USING (event_id)
LEFT JOIN event_auth_chains USING (event_id) LEFT JOIN event_auth_chains USING (event_id)
LEFT JOIN event_auth_chain_to_calculate USING (event_id) LEFT JOIN event_auth_chain_to_calculate USING (event_id)
WHERE events.room_id = ? WHERE event_auth_chains.event_id IS NULL
AND event_auth_chains.event_id IS NULL
AND event_auth_chain_to_calculate.event_id IS NULL AND event_auth_chain_to_calculate.event_id IS NULL
AND %(tuple_cmp)s AND %(tuple_cmp)s
ORDER BY topological_ordering, stream_ordering %(extra)s
LIMIT ? ORDER BY events.room_id, topological_ordering, stream_ordering
%(limit)s
""" % { """ % {
"tuple_cmp": tuple_clause, "tuple_cmp": tuple_clause,
"limit": "LIMIT ?" if batch_size is not None else "",
"extra": extra_clause,
} }
args = [current_room_id] if batch_size is not None:
args.extend(tuple_args) tuple_args.append(batch_size)
args.append(batch_size)
txn.execute(sql, args) txn.execute(sql, tuple_args)
rows = txn.fetchall() rows = txn.fetchall()
# Put the results in the necessary format for # Put the results in the necessary format for
# `_add_chain_cover_index` # `_add_chain_cover_index`
event_to_room_id = {row[0]: current_room_id for row in rows} event_to_room_id = {row[0]: row[5] for row in rows}
event_to_types = {row[0]: (row[1], row[2]) for row in rows} event_to_types = {row[0]: (row[1], row[2]) for row in rows}
# Calculate the new last position we've processed up to.
new_last_depth = rows[-1][3] if rows else last_depth # type: int new_last_depth = rows[-1][3] if rows else last_depth # type: int
new_last_stream = rows[-1][4] if rows else last_stream # type: int new_last_stream = rows[-1][4] if rows else last_stream # type: int
new_last_room_id = rows[-1][5] if rows else "" # type: str
# Map from room_id to last depth/stream_ordering processed for the room,
# excluding the last room (which we're likely still processing). We also
# need to include the room passed in if it's not included in the result
# set (as we then know we've processed all events in said room).
#
# This is the set of rooms that we can now safely flip the
# `has_auth_chain_index` bit for.
finished_rooms = {
row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id
}
if last_room_id not in finished_rooms and last_room_id != new_last_room_id:
finished_rooms[last_room_id] = (last_depth, last_stream)
count = len(rows) count = len(rows)
@ -826,76 +911,20 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
event_to_auth_chain = {} # type: Dict[str, List[str]] event_to_auth_chain = {} # type: Dict[str, List[str]]
for row in auth_events: for row in auth_events:
event_to_auth_chain.setdefault(row["event_id"], []).append( event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
row["auth_id"]
)
# Calculate and persist the chain cover index for this set of events. # Calculate and persist the chain cover index for this set of events.
# #
# Annoyingly we need to gut wrench into the persit event store so that # Annoyingly we need to gut wrench into the persit event store so that
# we can reuse the function to calculate the chain cover for rooms. # we can reuse the function to calculate the chain cover for rooms.
PersistEventsStore._add_chain_cover_index( PersistEventsStore._add_chain_cover_index(
txn, txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
self.db_pool,
event_to_room_id,
event_to_types,
event_to_auth_chain,
) )
return new_last_depth, new_last_stream, count return _CalculateChainCover(
room_id=new_last_room_id,
last_depth, last_stream, count = await self.db_pool.runInteraction( depth=new_last_depth,
"_chain_cover_index", _calculate_auth_chain, last_depth, last_stream stream=new_last_stream,
processed_count=count,
finished_room_map=finished_rooms,
) )
total_rows_processed = count
if count < batch_size and not has_set_room_has_chain_index:
# If we've done all the events in the room we flip the
# `has_auth_chain_index` in the DB. Note that its possible for
# further events to be persisted between the above and setting the
# flag without having the chain cover calculated for them. This is
# fine as a) the code gracefully handles these cases and b) we'll
# calculate them below.
await self.db_pool.simple_update(
table="rooms",
keyvalues={"room_id": current_room_id},
updatevalues={"has_auth_chain_index": True},
desc="_chain_cover_index",
)
has_set_room_has_chain_index = True
# Handle any events that might have raced with us flipping the
# bit above.
last_depth, last_stream, count = await self.db_pool.runInteraction(
"_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
)
total_rows_processed += count
# Note that at this point its technically possible that more events
# than our `batch_size` have been persisted without their chain
# cover, so we need to continue processing this room if the last
# count returned was equal to the `batch_size`.
if count < batch_size:
# We've finished calculating the index for this room, move on to the
# next room.
await self.db_pool.updates._background_update_progress(
"chain_cover", {"current_room_id": current_room_id, "finished": True},
)
else:
# We still have outstanding events to calculate the index for.
await self.db_pool.updates._background_update_progress(
"chain_cover",
{
"current_room_id": current_room_id,
"last_depth": last_depth,
"last_stream": last_stream,
"has_auth_chain_index": has_set_room_has_chain_index,
"finished": False,
},
)
return total_rows_processed

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict, List, Tuple from typing import Dict, List, Set, Tuple
from twisted.trial import unittest from twisted.trial import unittest
@ -483,22 +483,20 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def test_background_update(self): def prepare(self, reactor, clock, hs):
"""Test that the background update to calculate auth chains for historic self.store = hs.get_datastore()
rooms works correctly. self.user_id = self.register_user("foo", "pass")
self.token = self.login("foo", "pass")
self.requester = create_requester(self.user_id)
def _generate_room(self) -> Tuple[str, List[Set[str]]]:
"""Insert a room without a chain cover index.
""" """
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
# Create a room
user_id = self.register_user("foo", "pass")
token = self.login("foo", "pass")
room_id = self.helper.create_room_as(user_id, tok=token)
requester = create_requester(user_id)
store = self.hs.get_datastore()
# Mark the room as not having a chain cover index # Mark the room as not having a chain cover index
self.get_success( self.get_success(
store.db_pool.simple_update( self.store.db_pool.simple_update(
table="rooms", table="rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
updatevalues={"has_auth_chain_index": False}, updatevalues={"has_auth_chain_index": False},
@ -508,42 +506,44 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
# Create a fork in the DAG with different events. # Create a fork in the DAG with different events.
event_handler = self.hs.get_event_creation_handler() event_handler = self.hs.get_event_creation_handler()
latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) latest_event_ids = self.get_success(
self.store.get_prev_events_for_room(room_id)
)
event, context = self.get_success( event, context = self.get_success(
event_handler.create_event( event_handler.create_event(
requester, self.requester,
{ {
"type": "some_state_type", "type": "some_state_type",
"state_key": "", "state_key": "",
"content": {}, "content": {},
"room_id": room_id, "room_id": room_id,
"sender": user_id, "sender": self.user_id,
}, },
prev_event_ids=latest_event_ids, prev_event_ids=latest_event_ids,
) )
) )
self.get_success( self.get_success(
event_handler.handle_new_client_event(requester, event, context) event_handler.handle_new_client_event(self.requester, event, context)
) )
state1 = list(self.get_success(context.get_current_state_ids()).values()) state1 = set(self.get_success(context.get_current_state_ids()).values())
event, context = self.get_success( event, context = self.get_success(
event_handler.create_event( event_handler.create_event(
requester, self.requester,
{ {
"type": "some_state_type", "type": "some_state_type",
"state_key": "", "state_key": "",
"content": {}, "content": {},
"room_id": room_id, "room_id": room_id,
"sender": user_id, "sender": self.user_id,
}, },
prev_event_ids=latest_event_ids, prev_event_ids=latest_event_ids,
) )
) )
self.get_success( self.get_success(
event_handler.handle_new_client_event(requester, event, context) event_handler.handle_new_client_event(self.requester, event, context)
) )
state2 = list(self.get_success(context.get_current_state_ids()).values()) state2 = set(self.get_success(context.get_current_state_ids()).values())
# Delete the chain cover info. # Delete the chain cover info.
@ -551,36 +551,191 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
txn.execute("DELETE FROM event_auth_chains") txn.execute("DELETE FROM event_auth_chains")
txn.execute("DELETE FROM event_auth_chain_links") txn.execute("DELETE FROM event_auth_chain_links")
self.get_success(store.db_pool.runInteraction("test", _delete_tables)) self.get_success(self.store.db_pool.runInteraction("test", _delete_tables))
return room_id, [state1, state2]
def test_background_update_single_room(self):
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
# Create a room
room_id, states = self._generate_room()
# Insert and run the background update. # Insert and run the background update.
self.get_success( self.get_success(
store.db_pool.simple_insert( self.store.db_pool.simple_insert(
"background_updates", "background_updates",
{"update_name": "chain_cover", "progress_json": "{}"}, {"update_name": "chain_cover", "progress_json": "{}"},
) )
) )
# Ugh, have to reset this flag # Ugh, have to reset this flag
store.db_pool.updates._all_done = False self.store.db_pool.updates._all_done = False
while not self.get_success( while not self.get_success(
store.db_pool.updates.has_completed_background_updates() self.store.db_pool.updates.has_completed_background_updates()
): ):
self.get_success( self.get_success(
store.db_pool.updates.do_next_background_update(100), by=0.1 self.store.db_pool.updates.do_next_background_update(100), by=0.1
) )
# Test that the `has_auth_chain_index` has been set # Test that the `has_auth_chain_index` has been set
self.assertTrue(self.get_success(store.has_auth_chain_index(room_id))) self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
# Test that calculating the auth chain difference using the newly # Test that calculating the auth chain difference using the newly
# calculated chain cover works. # calculated chain cover works.
self.get_success( self.get_success(
store.db_pool.runInteraction( self.store.db_pool.runInteraction(
"test", "test",
store._get_auth_chain_difference_using_cover_index_txn, self.store._get_auth_chain_difference_using_cover_index_txn,
room_id, room_id,
[state1, state2], states,
) )
) )
def test_background_update_multiple_rooms(self):
    """Test that the background update calculates the chain cover index for
    several historic rooms, checking every room rather than only the first.
    """

    # Create three rooms lacking a chain cover index. Each entry is the
    # `(room_id, states)` pair returned by `_generate_room`; all of them are
    # kept so that no room's state sets are discarded.
    rooms = [self._generate_room() for _ in range(3)]

    # Insert and run the background update.
    self.get_success(
        self.store.db_pool.simple_insert(
            "background_updates",
            {"update_name": "chain_cover", "progress_json": "{}"},
        )
    )

    # Ugh, have to reset this flag
    self.store.db_pool.updates._all_done = False

    while not self.get_success(
        self.store.db_pool.updates.has_completed_background_updates()
    ):
        self.get_success(
            self.store.db_pool.updates.do_next_background_update(100), by=0.1
        )

    for room_id, states in rooms:
        # Test that the `has_auth_chain_index` has been set
        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))

        # Test that calculating the auth chain difference using the newly
        # calculated chain cover works for this room.
        self.get_success(
            self.store.db_pool.runInteraction(
                "test",
                self.store._get_auth_chain_difference_using_cover_index_txn,
                room_id,
                states,
            )
        )
def test_background_update_single_large_room(self):
    """Test that the background update handles one room that needs more
    than a single iteration to process.
    """

    room_id, states = self._generate_room()

    # Pad the room with extra state events so that it cannot be covered in
    # one iteration of the background update.
    for index in range(150):
        self.helper.send_state(
            room_id, event_type="m.test", body={"index": index}, tok=self.token
        )

    updater = self.store.db_pool.updates

    # Register the `chain_cover` background update with empty progress.
    self.get_success(
        self.store.db_pool.simple_insert(
            "background_updates",
            {"update_name": "chain_cover", "progress_json": "{}"},
        )
    )

    # The flag has to be reset by hand (presumably it records that all
    # updates had already completed -- confirm against the updater).
    updater._all_done = False

    # Drive the updater until it reports completion, counting the passes.
    iterations = 0
    while True:
        if self.get_success(updater.has_completed_background_updates()):
            break
        iterations += 1
        self.get_success(updater.do_next_background_update(100), by=0.1)

    # The room must really have taken more than one iteration.
    self.assertGreater(iterations, 1)

    # The room should now be flagged as having a chain cover index.
    has_index = self.get_success(self.store.has_auth_chain_index(room_id))
    self.assertTrue(has_index)

    # The auth chain difference must be computable from the new index.
    self.get_success(
        self.store.db_pool.runInteraction(
            "test",
            self.store._get_auth_chain_difference_using_cover_index_txn,
            room_id,
            states,
        )
    )
def test_background_update_multiple_large_room(self):
    """Test that the background update handles two rooms that together need
    more than a single iteration to process.
    """

    # Only the room IDs are needed here; the state sets are not checked.
    first_room, _ = self._generate_room()
    second_room, _ = self._generate_room()

    # Pad each room in turn with extra state events so that the update
    # cannot finish in one iteration.
    for target_room in (first_room, second_room):
        for index in range(150):
            self.helper.send_state(
                target_room,
                event_type="m.test",
                body={"index": index},
                tok=self.token,
            )

    updater = self.store.db_pool.updates

    # Register the `chain_cover` background update with empty progress.
    self.get_success(
        self.store.db_pool.simple_insert(
            "background_updates",
            {"update_name": "chain_cover", "progress_json": "{}"},
        )
    )

    # The flag has to be reset by hand (presumably it records that all
    # updates had already completed -- confirm against the updater).
    updater._all_done = False

    # Drive the updater until it reports completion, counting the passes.
    iterations = 0
    while True:
        if self.get_success(updater.has_completed_background_updates()):
            break
        iterations += 1
        self.get_success(updater.do_next_background_update(100), by=0.1)

    # The rooms must really have taken more than one iteration.
    self.assertGreater(iterations, 1)

    # Both rooms should now be flagged as having a chain cover index.
    for finished_room in (first_room, second_room):
        has_index = self.get_success(self.store.has_auth_chain_index(finished_room))
        self.assertTrue(has_index)