Collect the invalidate callbacks on the transaction object rather than passing around a separate list

This commit is contained in:
Mark Haines 2015-05-05 17:32:21 +01:00
parent 041b6cba61
commit d18f37e026
7 changed files with 51 additions and 51 deletions

View file

@ -185,12 +185,16 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object """An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute() passed to the constructor. Adds logging and metrics to the .execute()
method.""" method."""
__slots__ = ["txn", "name", "database_engine"] __slots__ = ["txn", "name", "database_engine", "after_callbacks"]
def __init__(self, txn, name, database_engine): def __init__(self, txn, name, database_engine, after_callbacks):
object.__setattr__(self, "txn", txn) object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name) object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine) object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks)
def call_after(self, callback, *args):
self.after_callbacks.append((callback, args))
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.txn, name) return getattr(self.txn, name)
@ -336,6 +340,8 @@ class SQLBaseStore(object):
start_time = time.time() * 1000 start_time = time.time() * 1000
after_callbacks = []
def inner_func(conn, *args, **kwargs): def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context: with LoggingContext("runInteraction") as context:
if self.database_engine.is_connection_closed(conn): if self.database_engine.is_connection_closed(conn):
@ -360,10 +366,10 @@ class SQLBaseStore(object):
while True: while True:
try: try:
txn = conn.cursor() txn = conn.cursor()
return func( txn = LoggingTransaction(
LoggingTransaction(txn, name, self.database_engine), txn, name, self.database_engine, after_callbacks
*args, **kwargs
) )
return func(txn, *args, **kwargs)
except self.database_engine.module.OperationalError as e: except self.database_engine.module.OperationalError as e:
# This can happen if the database disappears mid # This can happen if the database disappears mid
# transaction. # transaction.
@ -412,6 +418,8 @@ class SQLBaseStore(object):
result = yield self._db_pool.runWithConnection( result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs inner_func, *args, **kwargs
) )
for after_callback, after_args in after_callbacks:
after_callback(*after_args)
defer.returnValue(result) defer.returnValue(result)
def cursor_to_dict(self, cursor): def cursor_to_dict(self, cursor):

View file

@ -241,7 +241,7 @@ class EventFederationStore(SQLBaseStore):
return int(min_depth) if min_depth is not None else None return int(min_depth) if min_depth is not None else None
def _update_min_depth_for_room_txn(self, txn, invalidates, room_id, depth): def _update_min_depth_for_room_txn(self, txn, room_id, depth):
min_depth = self._get_min_depth_interaction(txn, room_id) min_depth = self._get_min_depth_interaction(txn, room_id)
do_insert = depth < min_depth if min_depth else True do_insert = depth < min_depth if min_depth else True
@ -256,8 +256,8 @@ class EventFederationStore(SQLBaseStore):
}, },
) )
def _handle_prev_events(self, txn, invalidates, outlier, event_id, def _handle_prev_events(self, txn, outlier, event_id, prev_events,
prev_events, room_id): room_id):
""" """
For the given event, update the event edges table and forward and For the given event, update the event edges table and forward and
backward extremities tables. backward extremities tables.
@ -330,9 +330,9 @@ class EventFederationStore(SQLBaseStore):
) )
txn.execute(query) txn.execute(query)
invalidates.append(( txn.call_after(
self.get_latest_event_ids_in_room.invalidate, room_id self.get_latest_event_ids_in_room.invalidate, room_id
)) )
def get_backfill_events(self, room_id, event_list, limit): def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and """Get a list of Events for a given topic that occurred before (and

View file

@ -42,7 +42,7 @@ class EventsStore(SQLBaseStore):
stream_ordering = self.min_token stream_ordering = self.min_token
try: try:
invalidates = yield self.runInteraction( yield self.runInteraction(
"persist_event", "persist_event",
self._persist_event_txn, self._persist_event_txn,
event=event, event=event,
@ -52,11 +52,6 @@ class EventsStore(SQLBaseStore):
is_new_state=is_new_state, is_new_state=is_new_state,
current_state=current_state, current_state=current_state,
) )
for invalidated in invalidates:
invalidated_callback = invalidated[0]
invalidated_args = invalidated[1:]
invalidated_callback(*invalidated_args)
except _RollbackButIsFineException: except _RollbackButIsFineException:
pass pass
@ -96,10 +91,9 @@ class EventsStore(SQLBaseStore):
def _persist_event_txn(self, txn, event, context, backfilled, def _persist_event_txn(self, txn, event, context, backfilled,
stream_ordering=None, is_new_state=True, stream_ordering=None, is_new_state=True,
current_state=None): current_state=None):
invalidates = []
# Remove the any existing cache entries for the event_id # Remove the any existing cache entries for the event_id
invalidates.append((self._invalidate_get_event_cache, event.event_id)) txn.call_after(self._invalidate_get_event_cache, event.event_id)
if stream_ordering is None: if stream_ordering is None:
with self._stream_id_gen.get_next_txn(txn) as stream_ordering: with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
@ -121,10 +115,12 @@ class EventsStore(SQLBaseStore):
for s in current_state: for s in current_state:
if s.type == EventTypes.Member: if s.type == EventTypes.Member:
invalidates.extend([ txn.call_after(
(self.get_rooms_for_user.invalidate, s.state_key), self.get_rooms_for_user.invalidate, s.state_key
(self.get_joined_hosts_for_room.invalidate, s.room_id), )
]) txn.call_after(
self.get_joined_hosts_for_room.invalidate, s.room_id
)
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
"current_state_events", "current_state_events",
@ -161,11 +157,10 @@ class EventsStore(SQLBaseStore):
outlier = event.internal_metadata.is_outlier() outlier = event.internal_metadata.is_outlier()
if not outlier: if not outlier:
self._store_state_groups_txn(txn, invalidates, event, context) self._store_state_groups_txn(txn, event, context)
self._update_min_depth_for_room_txn( self._update_min_depth_for_room_txn(
txn, txn,
invalidates,
event.room_id, event.room_id,
event.depth event.depth
) )
@ -207,11 +202,10 @@ class EventsStore(SQLBaseStore):
sql, sql,
(False, event.event_id,) (False, event.event_id,)
) )
return invalidates return
self._handle_prev_events( self._handle_prev_events(
txn, txn,
invalidates,
outlier=outlier, outlier=outlier,
event_id=event.event_id, event_id=event.event_id,
prev_events=event.prev_events, prev_events=event.prev_events,
@ -219,13 +213,13 @@ class EventsStore(SQLBaseStore):
) )
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
self._store_room_member_txn(txn, invalidates, event) self._store_room_member_txn(txn, event)
elif event.type == EventTypes.Name: elif event.type == EventTypes.Name:
self._store_room_name_txn(txn, invalidates, event) self._store_room_name_txn(txn, event)
elif event.type == EventTypes.Topic: elif event.type == EventTypes.Topic:
self._store_room_topic_txn(txn, invalidates, event) self._store_room_topic_txn(txn, event)
elif event.type == EventTypes.Redaction: elif event.type == EventTypes.Redaction:
self._store_redaction(txn, invalidates, event) self._store_redaction(txn, event)
event_dict = { event_dict = {
k: v k: v
@ -295,20 +289,20 @@ class EventsStore(SQLBaseStore):
if context.rejected: if context.rejected:
self._store_rejections_txn( self._store_rejections_txn(
txn, invalidates, event.event_id, context.rejected txn, event.event_id, context.rejected
) )
for hash_alg, hash_base64 in event.hashes.items(): for hash_alg, hash_base64 in event.hashes.items():
hash_bytes = decode_base64(hash_base64) hash_bytes = decode_base64(hash_base64)
self._store_event_content_hash_txn( self._store_event_content_hash_txn(
txn, invalidates, event.event_id, hash_alg, hash_bytes, txn, event.event_id, hash_alg, hash_bytes,
) )
for prev_event_id, prev_hashes in event.prev_events: for prev_event_id, prev_hashes in event.prev_events:
for alg, hash_base64 in prev_hashes.items(): for alg, hash_base64 in prev_hashes.items():
hash_bytes = decode_base64(hash_base64) hash_bytes = decode_base64(hash_base64)
self._store_prev_event_hash_txn( self._store_prev_event_hash_txn(
txn, invalidates, event.event_id, prev_event_id, alg, txn, event.event_id, prev_event_id, alg,
hash_bytes hash_bytes
) )
@ -325,7 +319,7 @@ class EventsStore(SQLBaseStore):
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event) (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
self._store_event_reference_hash_txn( self._store_event_reference_hash_txn(
txn, invalidates, event.event_id, ref_alg, ref_hash_bytes txn, event.event_id, ref_alg, ref_hash_bytes
) )
if event.is_state(): if event.is_state():
@ -372,11 +366,11 @@ class EventsStore(SQLBaseStore):
} }
) )
return invalidates return
def _store_redaction(self, txn, invalidates, event): def _store_redaction(self, txn, event):
# invalidate the cache for the redacted event # invalidate the cache for the redacted event
invalidates.append((self._invalidate_get_event_cache, event.redacts)) txn.call_after(self._invalidate_get_event_cache, event.redacts)
txn.execute( txn.execute(
"INSERT INTO redactions (event_id, redacts) VALUES (?,?)", "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
(event.event_id, event.redacts) (event.event_id, event.redacts)

View file

@ -162,7 +162,7 @@ class RoomStore(SQLBaseStore):
defer.returnValue(ret) defer.returnValue(ret)
def _store_room_topic_txn(self, txn, invalidates, event): def _store_room_topic_txn(self, txn, event):
if hasattr(event, "content") and "topic" in event.content: if hasattr(event, "content") and "topic" in event.content:
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
@ -174,7 +174,7 @@ class RoomStore(SQLBaseStore):
}, },
) )
def _store_room_name_txn(self, txn, invalidates, event): def _store_room_name_txn(self, txn, event):
if hasattr(event, "content") and "name" in event.content: if hasattr(event, "content") and "name" in event.content:
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,

View file

@ -35,7 +35,7 @@ RoomsForUser = namedtuple(
class RoomMemberStore(SQLBaseStore): class RoomMemberStore(SQLBaseStore):
def _store_room_member_txn(self, txn, invalidates, event): def _store_room_member_txn(self, txn, event):
"""Store a room member in the database. """Store a room member in the database.
""" """
try: try:
@ -64,10 +64,8 @@ class RoomMemberStore(SQLBaseStore):
} }
) )
invalidates.extend([ txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
(self.get_rooms_for_user.invalidate, target_user_id), txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
(self.get_joined_hosts_for_room.invalidate, event.room_id),
])
def get_room_member(self, user_id, room_id): def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member. """Retrieve the current state of a room member.

View file

@ -39,8 +39,8 @@ class SignatureStore(SQLBaseStore):
txn.execute(query, (event_id, )) txn.execute(query, (event_id, ))
return dict(txn.fetchall()) return dict(txn.fetchall())
def _store_event_content_hash_txn(self, txn, invalidates, event_id, def _store_event_content_hash_txn(self, txn, event_id, algorithm,
algorithm, hash_bytes): hash_bytes):
"""Store a hash for a Event """Store a hash for a Event
Args: Args:
txn (cursor): txn (cursor):
@ -101,8 +101,8 @@ class SignatureStore(SQLBaseStore):
txn.execute(query, (event_id, )) txn.execute(query, (event_id, ))
return {k: v for k, v in txn.fetchall()} return {k: v for k, v in txn.fetchall()}
def _store_event_reference_hash_txn(self, txn, invalidates, event_id, def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
algorithm, hash_bytes): hash_bytes):
"""Store a hash for a PDU """Store a hash for a PDU
Args: Args:
txn (cursor): txn (cursor):
@ -184,8 +184,8 @@ class SignatureStore(SQLBaseStore):
hashes[algorithm] = hash_bytes hashes[algorithm] = hash_bytes
return results return results
def _store_prev_event_hash_txn(self, txn, invalidates, event_id, def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id,
prev_event_id, algorithm, hash_bytes): algorithm, hash_bytes):
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
"event_edge_hashes", "event_edge_hashes",

View file

@ -82,7 +82,7 @@ class StateStore(SQLBaseStore):
f, f,
) )
def _store_state_groups_txn(self, txn, invalidates, event, context): def _store_state_groups_txn(self, txn, event, context):
if context.current_state is None: if context.current_state is None:
return return