Make StreamIdGen get_next and get_next_mult async (#8161)

This is mainly so that `StreamIdGenerator` and `MultiWriterIdGenerator`
will have the same interface, allowing them to be used interchangeably.
This commit is contained in:
Erik Johnston 2020-08-25 15:10:08 +01:00 committed by GitHub
parent 74bf8d4d06
commit 2231dffee6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 54 additions and 49 deletions

1
changelog.d/8161.misc Normal file
View file

@@ -0,0 +1 @@
Refactor `StreamIdGenerator` and `MultiWriterIdGenerator` to have the same interface.

View file

@@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore):
""" """
content_json = json_encoder.encode(content) content_json = json_encoder.encode(content)
with self._account_data_id_gen.get_next() as next_id: with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint # no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will # on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict. # retry if there is a conflict.
@@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore):
""" """
content_json = json_encoder.encode(content) content_json = json_encoder.encode(content)
with self._account_data_id_gen.get_next() as next_id: with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on # no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if # (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict. # there is a conflict.

View file

@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((destination, stream_id, now_ms, edu_json)) rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows) txn.executemany(sql, rows)
with self._device_inbox_id_gen.get_next() as stream_id: with await self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec() now_ms = self.clock.time_msec()
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
txn, stream_id, local_messages_by_user_then_device txn, stream_id, local_messages_by_user_then_device
) )
with self._device_inbox_id_gen.get_next() as stream_id: with await self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec() now_ms = self.clock.time_msec()
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox", "add_messages_from_remote_to_device_inbox",

View file

@@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore):
The new stream ID. The new stream ID.
""" """
with self._device_list_id_gen.get_next() as stream_id: with await self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"add_user_sig_change_to_streams", "add_user_sig_change_to_streams",
self._add_user_signature_change_txn, self._add_user_signature_change_txn,
@@ -1146,7 +1146,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids: if not device_ids:
return return
with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: with await self._device_list_id_gen.get_next_mult(
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"add_device_change_to_stream", "add_device_change_to_stream",
self._add_device_change_to_stream_txn, self._add_device_change_to_stream_txn,
@@ -1159,7 +1161,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return stream_ids[-1] return stream_ids[-1]
context = get_active_span_text_map() context = get_active_span_text_map()
with self._device_list_id_gen.get_next_mult( with await self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids) len(hosts) * len(device_ids)
) as stream_ids: ) as stream_ids:
await self.db_pool.runInteraction( await self.db_pool.runInteraction(

View file

@@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
) )
def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
"""Set a user's cross-signing key. """Set a user's cross-signing key.
Args: Args:
@@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
for a master key, 'self_signing' for a self-signing key, or for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key 'user_signing' for a user-signing key
key (dict): the key data key (dict): the key data
stream_id (int)
""" """
# the 'key' dict will look something like: # the 'key' dict will look something like:
# { # {
@@ -695,23 +696,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
) )
# and finally, store the key itself # and finally, store the key itself
with self._cross_signing_id_gen.get_next() as stream_id: self.db_pool.simple_insert_txn(
self.db_pool.simple_insert_txn( txn,
txn, "e2e_cross_signing_keys",
"e2e_cross_signing_keys", values={
values={ "user_id": user_id,
"user_id": user_id, "keytype": key_type,
"keytype": key_type, "keydata": json_encoder.encode(key),
"keydata": json_encoder.encode(key), "stream_id": stream_id,
"stream_id": stream_id, },
}, )
)
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self._get_bare_e2e_cross_signing_keys, (user_id,) txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
) )
def set_e2e_cross_signing_key(self, user_id, key_type, key): async def set_e2e_cross_signing_key(self, user_id, key_type, key):
"""Set a user's cross-signing key. """Set a user's cross-signing key.
Args: Args:
@@ -719,13 +719,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key_type (str): the type of cross-signing key to set key_type (str): the type of cross-signing key to set
key (dict): the key data key (dict): the key data
""" """
return self.db_pool.runInteraction(
"add_e2e_cross_signing_key", with await self._cross_signing_id_gen.get_next() as stream_id:
self._set_e2e_cross_signing_key_txn, return await self.db_pool.runInteraction(
user_id, "add_e2e_cross_signing_key",
key_type, self._set_e2e_cross_signing_key_txn,
key, user_id,
) key_type,
key,
stream_id,
)
def store_e2e_cross_signing_signatures(self, user_id, signatures): def store_e2e_cross_signing_signatures(self, user_id, signatures):
"""Stores cross-signing signatures. """Stores cross-signing signatures.

View file

@@ -153,11 +153,11 @@ class PersistEventsStore:
# Note: Multiple instances of this function cannot be in flight at # Note: Multiple instances of this function cannot be in flight at
# the same time for the same room. # the same time for the same room.
if backfilled: if backfilled:
stream_ordering_manager = self._backfill_id_gen.get_next_mult( stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
len(events_and_contexts) len(events_and_contexts)
) )
else: else:
stream_ordering_manager = self._stream_id_gen.get_next_mult( stream_ordering_manager = await self._stream_id_gen.get_next_mult(
len(events_and_contexts) len(events_and_contexts)
) )

View file

@@ -1182,7 +1182,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id return next_id
with self._group_updates_id_gen.get_next() as next_id: with await self._group_updates_id_gen.get_next() as next_id:
res = await self.db_pool.runInteraction( res = await self.db_pool.runInteraction(
"register_user_group_membership", "register_user_group_membership",
_register_user_group_membership_txn, _register_user_group_membership_txn,

View file

@@ -23,7 +23,7 @@ from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore): class PresenceStore(SQLBaseStore):
async def update_presence(self, presence_states): async def update_presence(self, presence_states):
stream_ordering_manager = self._presence_id_gen.get_next_mult( stream_ordering_manager = await self._presence_id_gen.get_next_mult(
len(presence_states) len(presence_states)
) )

View file

@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
) -> None: ) -> None:
conditions_json = json_encoder.encode(conditions) conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions) actions_json = json_encoder.encode(actions)
with self._push_rules_stream_id_gen.get_next() as stream_id: with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token() event_stream_ordering = self._stream_id_gen.get_current_token()
if before or after: if before or after:
@@ -560,7 +560,7 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
) )
with self._push_rules_stream_id_gen.get_next() as stream_id: with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token() event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
@@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore):
) )
async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None: async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
with self._push_rules_stream_id_gen.get_next() as stream_id: with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token() event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
@@ -646,7 +646,7 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json}, data={"actions": actions_json},
) )
with self._push_rules_stream_id_gen.get_next() as stream_id: with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token() event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction( await self.db_pool.runInteraction(

View file

@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
last_stream_ordering, last_stream_ordering,
profile_tag="", profile_tag="",
) -> None: ) -> None:
with self._pushers_id_gen.get_next() as stream_id: with await self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on # no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry # (app_id, pushkey, user_name) so simple_upsert will retry
await self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
}, },
) )
with self._pushers_id_gen.get_next() as stream_id: with await self._pushers_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id "delete_pusher", delete_pusher_txn, stream_id
) )

View file

@@ -520,8 +520,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"insert_receipt_conv", graph_to_linear "insert_receipt_conv", graph_to_linear
) )
stream_id_manager = self._receipts_id_gen.get_next() with await self._receipts_id_gen.get_next() as stream_id:
with stream_id_manager as stream_id:
event_ts = await self.db_pool.runInteraction( event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt", "insert_linearized_receipt",
self.insert_linearized_receipt_txn, self.insert_linearized_receipt_txn,

View file

@@ -1129,7 +1129,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
}, },
) )
with self._public_room_id_gen.get_next() as next_id: with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"store_room_txn", store_room_txn, next_id "store_room_txn", store_room_txn, next_id
) )
@@ -1196,7 +1196,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
}, },
) )
with self._public_room_id_gen.get_next() as next_id: with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id "set_room_is_public", set_room_is_public_txn, next_id
) )
@@ -1276,7 +1276,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
}, },
) )
with self._public_room_id_gen.get_next() as next_id: with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"set_room_is_public_appservice", "set_room_is_public_appservice",
set_room_is_public_appservice_txn, set_room_is_public_appservice_txn,

View file

@@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
) )
self._update_revision_txn(txn, user_id, room_id, next_id) self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id: with await self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,)) self.get_tags_for_user.invalidate((user_id,))
@@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
txn.execute(sql, (user_id, room_id, tag)) txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id) self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id: with await self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,)) self.get_tags_for_user.invalidate((user_id,))

View file

@@ -80,7 +80,7 @@ class StreamIdGenerator(object):
upwards, -1 to grow downwards. upwards, -1 to grow downwards.
Usage: Usage:
with stream_id_gen.get_next() as stream_id: with await stream_id_gen.get_next() as stream_id:
# ... persist event ... # ... persist event ...
""" """
@@ -95,10 +95,10 @@ class StreamIdGenerator(object):
) )
self._unfinished_ids = deque() # type: Deque[int] self._unfinished_ids = deque() # type: Deque[int]
def get_next(self): async def get_next(self):
""" """
Usage: Usage:
with stream_id_gen.get_next() as stream_id: with await stream_id_gen.get_next() as stream_id:
# ... persist event ... # ... persist event ...
""" """
with self._lock: with self._lock:
@@ -117,10 +117,10 @@ class StreamIdGenerator(object):
return manager() return manager()
def get_next_mult(self, n): async def get_next_mult(self, n):
""" """
Usage: Usage:
with stream_id_gen.get_next(n) as stream_ids: with await stream_id_gen.get_next(n) as stream_ids:
# ... persist events ... # ... persist events ...
""" """
with self._lock: with self._lock: