mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-22 09:35:45 +03:00
Store state groups separately from events (#2784)
* Split state group persist into seperate storage func * Add per database engine code for state group id gen * Move store_state_group to StateReadStore This allows other workers to use it, and so resolve state. * Hook up store_state_group * Fix tests * Rename _store_mult_state_groups_txn * Rename StateGroupReadStore * Remove redundant _have_persisted_state_group_txn * Update comments * Comment compute_event_context * Set start val for state_group_id_seq ... otherwise we try to recreate old state groups * Update comments * Don't store state for outliers * Update comment * Update docstring as state groups are ints
This commit is contained in:
parent
b31bf0bb51
commit
3d33eef6fc
12 changed files with 341 additions and 204 deletions
|
@ -25,7 +25,9 @@ class EventContext(object):
|
||||||
The current state map excluding the current event.
|
The current state map excluding the current event.
|
||||||
(type, state_key) -> event_id
|
(type, state_key) -> event_id
|
||||||
|
|
||||||
state_group (int): state group id
|
state_group (int|None): state group id, if the state has been stored
|
||||||
|
as a state group. This is usually only None if e.g. the event is
|
||||||
|
an outlier.
|
||||||
rejected (bool|str): A rejection reason if the event was rejected, else
|
rejected (bool|str): A rejection reason if the event was rejected, else
|
||||||
False
|
False
|
||||||
|
|
||||||
|
|
|
@ -1831,8 +1831,8 @@ class FederationHandler(BaseHandler):
|
||||||
current_state = set(e.event_id for e in auth_events.values())
|
current_state = set(e.event_id for e in auth_events.values())
|
||||||
different_auth = event_auth_events - current_state
|
different_auth = event_auth_events - current_state
|
||||||
|
|
||||||
self._update_context_for_auth_events(
|
yield self._update_context_for_auth_events(
|
||||||
context, auth_events, event_key,
|
event, context, auth_events, event_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
if different_auth and not event.internal_metadata.is_outlier():
|
if different_auth and not event.internal_metadata.is_outlier():
|
||||||
|
@ -1913,8 +1913,8 @@ class FederationHandler(BaseHandler):
|
||||||
# 4. Look at rejects and their proofs.
|
# 4. Look at rejects and their proofs.
|
||||||
# TODO.
|
# TODO.
|
||||||
|
|
||||||
self._update_context_for_auth_events(
|
yield self._update_context_for_auth_events(
|
||||||
context, auth_events, event_key,
|
event, context, auth_events, event_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -1923,11 +1923,15 @@ class FederationHandler(BaseHandler):
|
||||||
logger.warn("Failed auth resolution for %r because %s", event, e)
|
logger.warn("Failed auth resolution for %r because %s", event, e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def _update_context_for_auth_events(self, context, auth_events,
|
@defer.inlineCallbacks
|
||||||
|
def _update_context_for_auth_events(self, event, context, auth_events,
|
||||||
event_key):
|
event_key):
|
||||||
"""Update the state_ids in an event context after auth event resolution
|
"""Update the state_ids in an event context after auth event resolution,
|
||||||
|
storing the changes as a new state group.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
event (Event): The event we're handling the context for
|
||||||
|
|
||||||
context (synapse.events.snapshot.EventContext): event context
|
context (synapse.events.snapshot.EventContext): event context
|
||||||
to be updated
|
to be updated
|
||||||
|
|
||||||
|
@ -1950,7 +1954,13 @@ class FederationHandler(BaseHandler):
|
||||||
context.prev_state_ids.update({
|
context.prev_state_ids.update({
|
||||||
k: a.event_id for k, a in auth_events.iteritems()
|
k: a.event_id for k, a in auth_events.iteritems()
|
||||||
})
|
})
|
||||||
context.state_group = self.store.get_next_state_group()
|
context.state_group = yield self.store.store_state_group(
|
||||||
|
event.event_id,
|
||||||
|
event.room_id,
|
||||||
|
prev_group=context.prev_group,
|
||||||
|
delta_ids=context.delta_ids,
|
||||||
|
current_state_ids=context.current_state_ids,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def construct_auth_difference(self, local_auth, remote_auth):
|
def construct_auth_difference(self, local_auth, remote_auth):
|
||||||
|
|
|
@ -19,7 +19,7 @@ from synapse.storage import DataStore
|
||||||
from synapse.storage.event_federation import EventFederationStore
|
from synapse.storage.event_federation import EventFederationStore
|
||||||
from synapse.storage.event_push_actions import EventPushActionsStore
|
from synapse.storage.event_push_actions import EventPushActionsStore
|
||||||
from synapse.storage.roommember import RoomMemberStore
|
from synapse.storage.roommember import RoomMemberStore
|
||||||
from synapse.storage.state import StateGroupReadStore
|
from synapse.storage.state import StateGroupWorkerStore
|
||||||
from synapse.storage.stream import StreamStore
|
from synapse.storage.stream import StreamStore
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
from ._base import BaseSlavedStore
|
from ._base import BaseSlavedStore
|
||||||
|
@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
|
||||||
# the method descriptor on the DataStore and chuck them into our class.
|
# the method descriptor on the DataStore and chuck them into our class.
|
||||||
|
|
||||||
|
|
||||||
class SlavedEventStore(StateGroupReadStore, BaseSlavedStore):
|
class SlavedEventStore(StateGroupWorkerStore, BaseSlavedStore):
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(SlavedEventStore, self).__init__(db_conn, hs)
|
super(SlavedEventStore, self).__init__(db_conn, hs)
|
||||||
|
|
|
@ -183,8 +183,15 @@ class StateHandler(object):
|
||||||
def compute_event_context(self, event, old_state=None):
|
def compute_event_context(self, event, old_state=None):
|
||||||
"""Build an EventContext structure for the event.
|
"""Build an EventContext structure for the event.
|
||||||
|
|
||||||
|
This works out what the current state should be for the event, and
|
||||||
|
generates a new state group if necessary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event (synapse.events.EventBase):
|
event (synapse.events.EventBase):
|
||||||
|
old_state (dict|None): The state at the event if it can't be
|
||||||
|
calculated from existing events. This is normally only specified
|
||||||
|
when receiving an event from federation where we don't have the
|
||||||
|
prev events for, e.g. when backfilling.
|
||||||
Returns:
|
Returns:
|
||||||
synapse.events.snapshot.EventContext:
|
synapse.events.snapshot.EventContext:
|
||||||
"""
|
"""
|
||||||
|
@ -208,15 +215,22 @@ class StateHandler(object):
|
||||||
context.current_state_ids = {}
|
context.current_state_ids = {}
|
||||||
context.prev_state_ids = {}
|
context.prev_state_ids = {}
|
||||||
context.prev_state_events = []
|
context.prev_state_events = []
|
||||||
context.state_group = self.store.get_next_state_group()
|
|
||||||
|
# We don't store state for outliers, so we don't generate a state
|
||||||
|
# froup for it.
|
||||||
|
context.state_group = None
|
||||||
|
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
if old_state:
|
if old_state:
|
||||||
|
# We already have the state, so we don't need to calculate it.
|
||||||
|
# Let's just correctly fill out the context and create a
|
||||||
|
# new state group for it.
|
||||||
|
|
||||||
context = EventContext()
|
context = EventContext()
|
||||||
context.prev_state_ids = {
|
context.prev_state_ids = {
|
||||||
(s.type, s.state_key): s.event_id for s in old_state
|
(s.type, s.state_key): s.event_id for s in old_state
|
||||||
}
|
}
|
||||||
context.state_group = self.store.get_next_state_group()
|
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
key = (event.type, event.state_key)
|
key = (event.type, event.state_key)
|
||||||
|
@ -229,6 +243,14 @@ class StateHandler(object):
|
||||||
else:
|
else:
|
||||||
context.current_state_ids = context.prev_state_ids
|
context.current_state_ids = context.prev_state_ids
|
||||||
|
|
||||||
|
context.state_group = yield self.store.store_state_group(
|
||||||
|
event.event_id,
|
||||||
|
event.room_id,
|
||||||
|
prev_group=None,
|
||||||
|
delta_ids=None,
|
||||||
|
current_state_ids=context.current_state_ids,
|
||||||
|
)
|
||||||
|
|
||||||
context.prev_state_events = []
|
context.prev_state_events = []
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
|
@ -242,7 +264,8 @@ class StateHandler(object):
|
||||||
context = EventContext()
|
context = EventContext()
|
||||||
context.prev_state_ids = curr_state
|
context.prev_state_ids = curr_state
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
context.state_group = self.store.get_next_state_group()
|
# If this is a state event then we need to create a new state
|
||||||
|
# group for the state after this event.
|
||||||
|
|
||||||
key = (event.type, event.state_key)
|
key = (event.type, event.state_key)
|
||||||
if key in context.prev_state_ids:
|
if key in context.prev_state_ids:
|
||||||
|
@ -253,24 +276,43 @@ class StateHandler(object):
|
||||||
context.current_state_ids[key] = event.event_id
|
context.current_state_ids[key] = event.event_id
|
||||||
|
|
||||||
if entry.state_group:
|
if entry.state_group:
|
||||||
|
# If the state at the event has a state group assigned then
|
||||||
|
# we can use that as the prev group
|
||||||
context.prev_group = entry.state_group
|
context.prev_group = entry.state_group
|
||||||
context.delta_ids = {
|
context.delta_ids = {
|
||||||
key: event.event_id
|
key: event.event_id
|
||||||
}
|
}
|
||||||
elif entry.prev_group:
|
elif entry.prev_group:
|
||||||
|
# If the state at the event only has a prev group, then we can
|
||||||
|
# use that as a prev group too.
|
||||||
context.prev_group = entry.prev_group
|
context.prev_group = entry.prev_group
|
||||||
context.delta_ids = dict(entry.delta_ids)
|
context.delta_ids = dict(entry.delta_ids)
|
||||||
context.delta_ids[key] = event.event_id
|
context.delta_ids[key] = event.event_id
|
||||||
else:
|
|
||||||
if entry.state_group is None:
|
|
||||||
entry.state_group = self.store.get_next_state_group()
|
|
||||||
entry.state_id = entry.state_group
|
|
||||||
|
|
||||||
context.state_group = entry.state_group
|
context.state_group = yield self.store.store_state_group(
|
||||||
|
event.event_id,
|
||||||
|
event.room_id,
|
||||||
|
prev_group=context.prev_group,
|
||||||
|
delta_ids=context.delta_ids,
|
||||||
|
current_state_ids=context.current_state_ids,
|
||||||
|
)
|
||||||
|
else:
|
||||||
context.current_state_ids = context.prev_state_ids
|
context.current_state_ids = context.prev_state_ids
|
||||||
context.prev_group = entry.prev_group
|
context.prev_group = entry.prev_group
|
||||||
context.delta_ids = entry.delta_ids
|
context.delta_ids = entry.delta_ids
|
||||||
|
|
||||||
|
if entry.state_group is None:
|
||||||
|
entry.state_group = yield self.store.store_state_group(
|
||||||
|
event.event_id,
|
||||||
|
event.room_id,
|
||||||
|
prev_group=entry.prev_group,
|
||||||
|
delta_ids=entry.delta_ids,
|
||||||
|
current_state_ids=context.current_state_ids,
|
||||||
|
)
|
||||||
|
entry.state_id = entry.state_group
|
||||||
|
|
||||||
|
context.state_group = entry.state_group
|
||||||
|
|
||||||
context.prev_state_events = []
|
context.prev_state_events = []
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
|
|
|
@ -124,7 +124,6 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
|
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
|
||||||
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
|
|
||||||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||||
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
||||||
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
||||||
|
|
|
@ -62,3 +62,9 @@ class PostgresEngine(object):
|
||||||
|
|
||||||
def lock_table(self, txn, table):
|
def lock_table(self, txn, table):
|
||||||
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
|
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
|
||||||
|
|
||||||
|
def get_next_state_group_id(self, txn):
|
||||||
|
"""Returns an int that can be used as a new state_group ID
|
||||||
|
"""
|
||||||
|
txn.execute("SELECT nextval('state_group_id_seq')")
|
||||||
|
return txn.fetchone()[0]
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
from synapse.storage.prepare_database import prepare_database
|
from synapse.storage.prepare_database import prepare_database
|
||||||
|
|
||||||
import struct
|
import struct
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
class Sqlite3Engine(object):
|
class Sqlite3Engine(object):
|
||||||
|
@ -24,6 +25,11 @@ class Sqlite3Engine(object):
|
||||||
def __init__(self, database_module, database_config):
|
def __init__(self, database_module, database_config):
|
||||||
self.module = database_module
|
self.module = database_module
|
||||||
|
|
||||||
|
# The current max state_group, or None if we haven't looked
|
||||||
|
# in the DB yet.
|
||||||
|
self._current_state_group_id = None
|
||||||
|
self._current_state_group_id_lock = threading.Lock()
|
||||||
|
|
||||||
def check_database(self, txn):
|
def check_database(self, txn):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -43,6 +49,19 @@ class Sqlite3Engine(object):
|
||||||
def lock_table(self, txn, table):
|
def lock_table(self, txn, table):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def get_next_state_group_id(self, txn):
|
||||||
|
"""Returns an int that can be used as a new state_group ID
|
||||||
|
"""
|
||||||
|
# We do application locking here since if we're using sqlite then
|
||||||
|
# we are a single process synapse.
|
||||||
|
with self._current_state_group_id_lock:
|
||||||
|
if self._current_state_group_id is None:
|
||||||
|
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
|
||||||
|
self._current_state_group_id = txn.fetchone()[0]
|
||||||
|
|
||||||
|
self._current_state_group_id += 1
|
||||||
|
return self._current_state_group_id
|
||||||
|
|
||||||
|
|
||||||
# Following functions taken from: https://github.com/coleifer/peewee
|
# Following functions taken from: https://github.com/coleifer/peewee
|
||||||
|
|
||||||
|
|
|
@ -755,9 +755,8 @@ class EventsStore(SQLBaseStore):
|
||||||
events_and_contexts=events_and_contexts,
|
events_and_contexts=events_and_contexts,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Insert into the state_groups, state_groups_state, and
|
# Insert into event_to_state_groups.
|
||||||
# event_to_state_groups tables.
|
self._store_event_state_mappings_txn(txn, events_and_contexts)
|
||||||
self._store_mult_state_groups_txn(txn, events_and_contexts)
|
|
||||||
|
|
||||||
# _store_rejected_events_txn filters out any events which were
|
# _store_rejected_events_txn filters out any events which were
|
||||||
# rejected, and returns the filtered list.
|
# rejected, and returns the filtered list.
|
||||||
|
@ -992,10 +991,9 @@ class EventsStore(SQLBaseStore):
|
||||||
# an outlier in the database. We now have some state at that
|
# an outlier in the database. We now have some state at that
|
||||||
# so we need to update the state_groups table with that state.
|
# so we need to update the state_groups table with that state.
|
||||||
|
|
||||||
# insert into the state_group, state_groups_state and
|
# insert into event_to_state_groups.
|
||||||
# event_to_state_groups tables.
|
|
||||||
try:
|
try:
|
||||||
self._store_mult_state_groups_txn(txn, ((event, context),))
|
self._store_event_state_mappings_txn(txn, ((event, context),))
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("")
|
logger.exception("")
|
||||||
raise
|
raise
|
||||||
|
|
37
synapse/storage/schema/delta/47/state_group_seq.py
Normal file
37
synapse/storage/schema/delta/47/state_group_seq.py
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.storage.engines import PostgresEngine
|
||||||
|
|
||||||
|
|
||||||
|
def run_create(cur, database_engine, *args, **kwargs):
|
||||||
|
if isinstance(database_engine, PostgresEngine):
|
||||||
|
# if we already have some state groups, we want to start making new
|
||||||
|
# ones with a higher id.
|
||||||
|
cur.execute("SELECT max(id) FROM state_groups")
|
||||||
|
row = cur.fetchone()
|
||||||
|
|
||||||
|
if row[0] is None:
|
||||||
|
start_val = 1
|
||||||
|
else:
|
||||||
|
start_val = row[0] + 1
|
||||||
|
|
||||||
|
cur.execute(
|
||||||
|
"CREATE SEQUENCE state_group_id_seq START WITH %s",
|
||||||
|
(start_val, ),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_upgrade(*args, **kwargs):
|
||||||
|
pass
|
|
@ -42,11 +42,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
|
||||||
return len(self.delta_ids) if self.delta_ids else 0
|
return len(self.delta_ids) if self.delta_ids else 0
|
||||||
|
|
||||||
|
|
||||||
class StateGroupReadStore(SQLBaseStore):
|
class StateGroupWorkerStore(SQLBaseStore):
|
||||||
"""The read-only parts of StateGroupStore
|
"""The parts of StateGroupStore that can be called from workers.
|
||||||
|
|
||||||
None of these functions write to the state tables, so are suitable for
|
|
||||||
including in the SlavedStores.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
|
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
|
||||||
|
@ -54,7 +51,7 @@ class StateGroupReadStore(SQLBaseStore):
|
||||||
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
|
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(StateGroupReadStore, self).__init__(db_conn, hs)
|
super(StateGroupWorkerStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
self._state_group_cache = DictionaryCache(
|
self._state_group_cache = DictionaryCache(
|
||||||
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
|
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
|
||||||
|
@ -549,8 +546,117 @@ class StateGroupReadStore(SQLBaseStore):
|
||||||
|
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
|
def store_state_group(self, event_id, room_id, prev_group, delta_ids,
|
||||||
|
current_state_ids):
|
||||||
|
"""Store a new set of state, returning a newly assigned state group.
|
||||||
|
|
||||||
class StateStore(StateGroupReadStore, BackgroundUpdateStore):
|
Args:
|
||||||
|
event_id (str): The event ID for which the state was calculated
|
||||||
|
room_id (str)
|
||||||
|
prev_group (int|None): A previous state group for the room, optional.
|
||||||
|
delta_ids (dict|None): The delta between state at `prev_group` and
|
||||||
|
`current_state_ids`, if `prev_group` was given. Same format as
|
||||||
|
`current_state_ids`.
|
||||||
|
current_state_ids (dict): The state to store. Map of (type, state_key)
|
||||||
|
to event_id.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[int]: The state group ID
|
||||||
|
"""
|
||||||
|
def _store_state_group_txn(txn):
|
||||||
|
if current_state_ids is None:
|
||||||
|
# AFAIK, this can never happen
|
||||||
|
raise Exception("current_state_ids cannot be None")
|
||||||
|
|
||||||
|
state_group = self.database_engine.get_next_state_group_id(txn)
|
||||||
|
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="state_groups",
|
||||||
|
values={
|
||||||
|
"id": state_group,
|
||||||
|
"room_id": room_id,
|
||||||
|
"event_id": event_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# We persist as a delta if we can, while also ensuring the chain
|
||||||
|
# of deltas isn't tooo long, as otherwise read performance degrades.
|
||||||
|
if prev_group:
|
||||||
|
is_in_db = self._simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="state_groups",
|
||||||
|
keyvalues={"id": prev_group},
|
||||||
|
retcol="id",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
if not is_in_db:
|
||||||
|
raise Exception(
|
||||||
|
"Trying to persist state with unpersisted prev_group: %r"
|
||||||
|
% (prev_group,)
|
||||||
|
)
|
||||||
|
|
||||||
|
potential_hops = self._count_state_group_hops_txn(
|
||||||
|
txn, prev_group
|
||||||
|
)
|
||||||
|
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="state_group_edges",
|
||||||
|
values={
|
||||||
|
"state_group": state_group,
|
||||||
|
"prev_state_group": prev_group,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._simple_insert_many_txn(
|
||||||
|
txn,
|
||||||
|
table="state_groups_state",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
|
"state_group": state_group,
|
||||||
|
"room_id": room_id,
|
||||||
|
"type": key[0],
|
||||||
|
"state_key": key[1],
|
||||||
|
"event_id": state_id,
|
||||||
|
}
|
||||||
|
for key, state_id in delta_ids.iteritems()
|
||||||
|
],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._simple_insert_many_txn(
|
||||||
|
txn,
|
||||||
|
table="state_groups_state",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
|
"state_group": state_group,
|
||||||
|
"room_id": room_id,
|
||||||
|
"type": key[0],
|
||||||
|
"state_key": key[1],
|
||||||
|
"event_id": state_id,
|
||||||
|
}
|
||||||
|
for key, state_id in current_state_ids.iteritems()
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prefill the state group cache with this group.
|
||||||
|
# It's fine to use the sequence like this as the state group map
|
||||||
|
# is immutable. (If the map wasn't immutable then this prefill could
|
||||||
|
# race with another update)
|
||||||
|
txn.call_after(
|
||||||
|
self._state_group_cache.update,
|
||||||
|
self._state_group_cache.sequence,
|
||||||
|
key=state_group,
|
||||||
|
value=dict(current_state_ids),
|
||||||
|
full=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return state_group
|
||||||
|
|
||||||
|
return self.runInteraction("store_state_group", _store_state_group_txn)
|
||||||
|
|
||||||
|
|
||||||
|
class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
|
||||||
""" Keeps track of the state at a given event.
|
""" Keeps track of the state at a given event.
|
||||||
|
|
||||||
This is done by the concept of `state groups`. Every event is a assigned
|
This is done by the concept of `state groups`. Every event is a assigned
|
||||||
|
@ -591,27 +697,12 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
|
||||||
where_clause="type='m.room.member'",
|
where_clause="type='m.room.member'",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _have_persisted_state_group_txn(self, txn, state_group):
|
def _store_event_state_mappings_txn(self, txn, events_and_contexts):
|
||||||
txn.execute(
|
|
||||||
"SELECT count(*) FROM state_groups WHERE id = ?",
|
|
||||||
(state_group,)
|
|
||||||
)
|
|
||||||
row = txn.fetchone()
|
|
||||||
return row and row[0]
|
|
||||||
|
|
||||||
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
|
|
||||||
state_groups = {}
|
state_groups = {}
|
||||||
for event, context in events_and_contexts:
|
for event, context in events_and_contexts:
|
||||||
if event.internal_metadata.is_outlier():
|
if event.internal_metadata.is_outlier():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if context.current_state_ids is None:
|
|
||||||
# AFAIK, this can never happen
|
|
||||||
logger.error(
|
|
||||||
"Non-outlier event %s had current_state_ids==None",
|
|
||||||
event.event_id)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# if the event was rejected, just give it the same state as its
|
# if the event was rejected, just give it the same state as its
|
||||||
# predecessor.
|
# predecessor.
|
||||||
if context.rejected:
|
if context.rejected:
|
||||||
|
@ -620,90 +711,6 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
|
||||||
|
|
||||||
state_groups[event.event_id] = context.state_group
|
state_groups[event.event_id] = context.state_group
|
||||||
|
|
||||||
if self._have_persisted_state_group_txn(txn, context.state_group):
|
|
||||||
continue
|
|
||||||
|
|
||||||
self._simple_insert_txn(
|
|
||||||
txn,
|
|
||||||
table="state_groups",
|
|
||||||
values={
|
|
||||||
"id": context.state_group,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"event_id": event.event_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# We persist as a delta if we can, while also ensuring the chain
|
|
||||||
# of deltas isn't tooo long, as otherwise read performance degrades.
|
|
||||||
if context.prev_group:
|
|
||||||
is_in_db = self._simple_select_one_onecol_txn(
|
|
||||||
txn,
|
|
||||||
table="state_groups",
|
|
||||||
keyvalues={"id": context.prev_group},
|
|
||||||
retcol="id",
|
|
||||||
allow_none=True,
|
|
||||||
)
|
|
||||||
if not is_in_db:
|
|
||||||
raise Exception(
|
|
||||||
"Trying to persist state with unpersisted prev_group: %r"
|
|
||||||
% (context.prev_group,)
|
|
||||||
)
|
|
||||||
|
|
||||||
potential_hops = self._count_state_group_hops_txn(
|
|
||||||
txn, context.prev_group
|
|
||||||
)
|
|
||||||
if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
|
|
||||||
self._simple_insert_txn(
|
|
||||||
txn,
|
|
||||||
table="state_group_edges",
|
|
||||||
values={
|
|
||||||
"state_group": context.state_group,
|
|
||||||
"prev_state_group": context.prev_group,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
self._simple_insert_many_txn(
|
|
||||||
txn,
|
|
||||||
table="state_groups_state",
|
|
||||||
values=[
|
|
||||||
{
|
|
||||||
"state_group": context.state_group,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"type": key[0],
|
|
||||||
"state_key": key[1],
|
|
||||||
"event_id": state_id,
|
|
||||||
}
|
|
||||||
for key, state_id in context.delta_ids.iteritems()
|
|
||||||
],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._simple_insert_many_txn(
|
|
||||||
txn,
|
|
||||||
table="state_groups_state",
|
|
||||||
values=[
|
|
||||||
{
|
|
||||||
"state_group": context.state_group,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"type": key[0],
|
|
||||||
"state_key": key[1],
|
|
||||||
"event_id": state_id,
|
|
||||||
}
|
|
||||||
for key, state_id in context.current_state_ids.iteritems()
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prefill the state group cache with this group.
|
|
||||||
# It's fine to use the sequence like this as the state group map
|
|
||||||
# is immutable. (If the map wasn't immutable then this prefill could
|
|
||||||
# race with another update)
|
|
||||||
txn.call_after(
|
|
||||||
self._state_group_cache.update,
|
|
||||||
self._state_group_cache.sequence,
|
|
||||||
key=context.state_group,
|
|
||||||
value=dict(context.current_state_ids),
|
|
||||||
full=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._simple_insert_many_txn(
|
self._simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="event_to_state_groups",
|
table="event_to_state_groups",
|
||||||
|
@ -763,9 +770,6 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
|
||||||
|
|
||||||
return count
|
return count
|
||||||
|
|
||||||
def get_next_state_group(self):
|
|
||||||
return self._state_groups_id_gen.get_next()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _background_deduplicate_state(self, progress, batch_size):
|
def _background_deduplicate_state(self, progress, batch_size):
|
||||||
"""This background update will slowly deduplicate state by reencoding
|
"""This background update will slowly deduplicate state by reencoding
|
||||||
|
|
|
@ -226,11 +226,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
context = EventContext()
|
context = EventContext()
|
||||||
context.current_state_ids = state_ids
|
context.current_state_ids = state_ids
|
||||||
context.prev_state_ids = state_ids
|
context.prev_state_ids = state_ids
|
||||||
elif not backfill:
|
else:
|
||||||
state_handler = self.hs.get_state_handler()
|
state_handler = self.hs.get_state_handler()
|
||||||
context = yield state_handler.compute_event_context(event)
|
context = yield state_handler.compute_event_context(event)
|
||||||
else:
|
|
||||||
context = EventContext()
|
|
||||||
|
|
||||||
context.push_actions = push_actions
|
context.push_actions = push_actions
|
||||||
|
|
||||||
|
|
|
@ -80,14 +80,14 @@ class StateGroupStore(object):
|
||||||
|
|
||||||
return defer.succeed(groups)
|
return defer.succeed(groups)
|
||||||
|
|
||||||
def store_state_groups(self, event, context):
|
def store_state_group(self, event_id, room_id, prev_group, delta_ids,
|
||||||
if context.current_state_ids is None:
|
current_state_ids):
|
||||||
return
|
state_group = self._next_group
|
||||||
|
self._next_group += 1
|
||||||
|
|
||||||
state_events = dict(context.current_state_ids)
|
self._group_to_state[state_group] = dict(current_state_ids)
|
||||||
|
|
||||||
self._group_to_state[context.state_group] = state_events
|
return state_group
|
||||||
self._event_to_state_group[event.event_id] = context.state_group
|
|
||||||
|
|
||||||
def get_events(self, event_ids, **kwargs):
|
def get_events(self, event_ids, **kwargs):
|
||||||
return {
|
return {
|
||||||
|
@ -95,10 +95,19 @@ class StateGroupStore(object):
|
||||||
if e_id in self._event_id_to_event
|
if e_id in self._event_id_to_event
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_state_group_delta(self, name):
|
||||||
|
return (None, None)
|
||||||
|
|
||||||
def register_events(self, events):
|
def register_events(self, events):
|
||||||
for e in events:
|
for e in events:
|
||||||
self._event_id_to_event[e.event_id] = e
|
self._event_id_to_event[e.event_id] = e
|
||||||
|
|
||||||
|
def register_event_context(self, event, context):
|
||||||
|
self._event_to_state_group[event.event_id] = context.state_group
|
||||||
|
|
||||||
|
def register_event_id_state_group(self, event_id, state_group):
|
||||||
|
self._event_to_state_group[event_id] = state_group
|
||||||
|
|
||||||
|
|
||||||
class DictObj(dict):
|
class DictObj(dict):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
|
@ -137,15 +146,7 @@ class Graph(object):
|
||||||
|
|
||||||
class StateTestCase(unittest.TestCase):
|
class StateTestCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.store = Mock(
|
self.store = StateGroupStore()
|
||||||
spec_set=[
|
|
||||||
"get_state_groups_ids",
|
|
||||||
"add_event_hashes",
|
|
||||||
"get_events",
|
|
||||||
"get_next_state_group",
|
|
||||||
"get_state_group_delta",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
hs = Mock(spec_set=[
|
hs = Mock(spec_set=[
|
||||||
"get_datastore", "get_auth", "get_state_handler", "get_clock",
|
"get_datastore", "get_auth", "get_state_handler", "get_clock",
|
||||||
"get_state_resolution_handler",
|
"get_state_resolution_handler",
|
||||||
|
@ -156,9 +157,6 @@ class StateTestCase(unittest.TestCase):
|
||||||
hs.get_auth.return_value = Auth(hs)
|
hs.get_auth.return_value = Auth(hs)
|
||||||
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
|
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
|
||||||
|
|
||||||
self.store.get_next_state_group.side_effect = Mock
|
|
||||||
self.store.get_state_group_delta.return_value = (None, None)
|
|
||||||
|
|
||||||
self.state = StateHandler(hs)
|
self.state = StateHandler(hs)
|
||||||
self.event_id = 0
|
self.event_id = 0
|
||||||
|
|
||||||
|
@ -197,14 +195,13 @@ class StateTestCase(unittest.TestCase):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
store = StateGroupStore()
|
self.store.register_events(graph.walk())
|
||||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
|
||||||
|
|
||||||
context_store = {}
|
context_store = {}
|
||||||
|
|
||||||
for event in graph.walk():
|
for event in graph.walk():
|
||||||
context = yield self.state.compute_event_context(event)
|
context = yield self.state.compute_event_context(event)
|
||||||
store.store_state_groups(event, context)
|
self.store.register_event_context(event, context)
|
||||||
context_store[event.event_id] = context
|
context_store[event.event_id] = context
|
||||||
|
|
||||||
self.assertEqual(2, len(context_store["D"].prev_state_ids))
|
self.assertEqual(2, len(context_store["D"].prev_state_ids))
|
||||||
|
@ -249,16 +246,13 @@ class StateTestCase(unittest.TestCase):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
store = StateGroupStore()
|
self.store.register_events(graph.walk())
|
||||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
|
||||||
self.store.get_events = store.get_events
|
|
||||||
store.register_events(graph.walk())
|
|
||||||
|
|
||||||
context_store = {}
|
context_store = {}
|
||||||
|
|
||||||
for event in graph.walk():
|
for event in graph.walk():
|
||||||
context = yield self.state.compute_event_context(event)
|
context = yield self.state.compute_event_context(event)
|
||||||
store.store_state_groups(event, context)
|
self.store.register_event_context(event, context)
|
||||||
context_store[event.event_id] = context
|
context_store[event.event_id] = context
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
|
@ -315,16 +309,13 @@ class StateTestCase(unittest.TestCase):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
store = StateGroupStore()
|
self.store.register_events(graph.walk())
|
||||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
|
||||||
self.store.get_events = store.get_events
|
|
||||||
store.register_events(graph.walk())
|
|
||||||
|
|
||||||
context_store = {}
|
context_store = {}
|
||||||
|
|
||||||
for event in graph.walk():
|
for event in graph.walk():
|
||||||
context = yield self.state.compute_event_context(event)
|
context = yield self.state.compute_event_context(event)
|
||||||
store.store_state_groups(event, context)
|
self.store.register_event_context(event, context)
|
||||||
context_store[event.event_id] = context
|
context_store[event.event_id] = context
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
|
@ -398,16 +389,13 @@ class StateTestCase(unittest.TestCase):
|
||||||
self._add_depths(nodes, edges)
|
self._add_depths(nodes, edges)
|
||||||
graph = Graph(nodes, edges)
|
graph = Graph(nodes, edges)
|
||||||
|
|
||||||
store = StateGroupStore()
|
self.store.register_events(graph.walk())
|
||||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
|
||||||
self.store.get_events = store.get_events
|
|
||||||
store.register_events(graph.walk())
|
|
||||||
|
|
||||||
context_store = {}
|
context_store = {}
|
||||||
|
|
||||||
for event in graph.walk():
|
for event in graph.walk():
|
||||||
context = yield self.state.compute_event_context(event)
|
context = yield self.state.compute_event_context(event)
|
||||||
store.store_state_groups(event, context)
|
self.store.register_event_context(event, context)
|
||||||
context_store[event.event_id] = context
|
context_store[event.event_id] = context
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
|
@ -467,7 +455,11 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_trivial_annotate_message(self):
|
def test_trivial_annotate_message(self):
|
||||||
event = create_event(type="test_message", name="event")
|
prev_event_id = "prev_event_id"
|
||||||
|
event = create_event(
|
||||||
|
type="test_message", name="event2",
|
||||||
|
prev_events=[(prev_event_id, {})],
|
||||||
|
)
|
||||||
|
|
||||||
old_state = [
|
old_state = [
|
||||||
create_event(type="test1", state_key="1"),
|
create_event(type="test1", state_key="1"),
|
||||||
|
@ -475,11 +467,11 @@ class StateTestCase(unittest.TestCase):
|
||||||
create_event(type="test2", state_key=""),
|
create_event(type="test2", state_key=""),
|
||||||
]
|
]
|
||||||
|
|
||||||
group_name = "group_name_1"
|
group_name = self.store.store_state_group(
|
||||||
|
prev_event_id, event.room_id, None, None,
|
||||||
self.store.get_state_groups_ids.return_value = {
|
{(e.type, e.state_key): e.event_id for e in old_state},
|
||||||
group_name: {(e.type, e.state_key): e.event_id for e in old_state},
|
)
|
||||||
}
|
self.store.register_event_id_state_group(prev_event_id, group_name)
|
||||||
|
|
||||||
context = yield self.state.compute_event_context(event)
|
context = yield self.state.compute_event_context(event)
|
||||||
|
|
||||||
|
@ -492,7 +484,11 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_trivial_annotate_state(self):
|
def test_trivial_annotate_state(self):
|
||||||
event = create_event(type="state", state_key="", name="event")
|
prev_event_id = "prev_event_id"
|
||||||
|
event = create_event(
|
||||||
|
type="state", state_key="", name="event2",
|
||||||
|
prev_events=[(prev_event_id, {})],
|
||||||
|
)
|
||||||
|
|
||||||
old_state = [
|
old_state = [
|
||||||
create_event(type="test1", state_key="1"),
|
create_event(type="test1", state_key="1"),
|
||||||
|
@ -500,11 +496,11 @@ class StateTestCase(unittest.TestCase):
|
||||||
create_event(type="test2", state_key=""),
|
create_event(type="test2", state_key=""),
|
||||||
]
|
]
|
||||||
|
|
||||||
group_name = "group_name_1"
|
group_name = self.store.store_state_group(
|
||||||
|
prev_event_id, event.room_id, None, None,
|
||||||
self.store.get_state_groups_ids.return_value = {
|
{(e.type, e.state_key): e.event_id for e in old_state},
|
||||||
group_name: {(e.type, e.state_key): e.event_id for e in old_state},
|
)
|
||||||
}
|
self.store.register_event_id_state_group(prev_event_id, group_name)
|
||||||
|
|
||||||
context = yield self.state.compute_event_context(event)
|
context = yield self.state.compute_event_context(event)
|
||||||
|
|
||||||
|
@ -517,7 +513,12 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_resolve_message_conflict(self):
|
def test_resolve_message_conflict(self):
|
||||||
event = create_event(type="test_message", name="event")
|
prev_event_id1 = "event_id1"
|
||||||
|
prev_event_id2 = "event_id2"
|
||||||
|
event = create_event(
|
||||||
|
type="test_message", name="event3",
|
||||||
|
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
|
||||||
|
)
|
||||||
|
|
||||||
creation = create_event(
|
creation = create_event(
|
||||||
type=EventTypes.Create, state_key=""
|
type=EventTypes.Create, state_key=""
|
||||||
|
@ -537,12 +538,12 @@ class StateTestCase(unittest.TestCase):
|
||||||
create_event(type="test4", state_key=""),
|
create_event(type="test4", state_key=""),
|
||||||
]
|
]
|
||||||
|
|
||||||
store = StateGroupStore()
|
self.store.register_events(old_state_1)
|
||||||
store.register_events(old_state_1)
|
self.store.register_events(old_state_2)
|
||||||
store.register_events(old_state_2)
|
|
||||||
self.store.get_events = store.get_events
|
|
||||||
|
|
||||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
context = yield self._get_context(
|
||||||
|
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(len(context.current_state_ids), 6)
|
self.assertEqual(len(context.current_state_ids), 6)
|
||||||
|
|
||||||
|
@ -550,7 +551,12 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_resolve_state_conflict(self):
|
def test_resolve_state_conflict(self):
|
||||||
event = create_event(type="test4", state_key="", name="event")
|
prev_event_id1 = "event_id1"
|
||||||
|
prev_event_id2 = "event_id2"
|
||||||
|
event = create_event(
|
||||||
|
type="test4", state_key="", name="event",
|
||||||
|
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
|
||||||
|
)
|
||||||
|
|
||||||
creation = create_event(
|
creation = create_event(
|
||||||
type=EventTypes.Create, state_key=""
|
type=EventTypes.Create, state_key=""
|
||||||
|
@ -575,7 +581,9 @@ class StateTestCase(unittest.TestCase):
|
||||||
store.register_events(old_state_2)
|
store.register_events(old_state_2)
|
||||||
self.store.get_events = store.get_events
|
self.store.get_events = store.get_events
|
||||||
|
|
||||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
context = yield self._get_context(
|
||||||
|
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(len(context.current_state_ids), 6)
|
self.assertEqual(len(context.current_state_ids), 6)
|
||||||
|
|
||||||
|
@ -583,7 +591,12 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_standard_depth_conflict(self):
|
def test_standard_depth_conflict(self):
|
||||||
event = create_event(type="test4", name="event")
|
prev_event_id1 = "event_id1"
|
||||||
|
prev_event_id2 = "event_id2"
|
||||||
|
event = create_event(
|
||||||
|
type="test4", name="event",
|
||||||
|
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
|
||||||
|
)
|
||||||
|
|
||||||
member_event = create_event(
|
member_event = create_event(
|
||||||
type=EventTypes.Member,
|
type=EventTypes.Member,
|
||||||
|
@ -615,7 +628,9 @@ class StateTestCase(unittest.TestCase):
|
||||||
store.register_events(old_state_2)
|
store.register_events(old_state_2)
|
||||||
self.store.get_events = store.get_events
|
self.store.get_events = store.get_events
|
||||||
|
|
||||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
context = yield self._get_context(
|
||||||
|
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
|
old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
|
||||||
|
@ -639,19 +654,26 @@ class StateTestCase(unittest.TestCase):
|
||||||
store.register_events(old_state_1)
|
store.register_events(old_state_1)
|
||||||
store.register_events(old_state_2)
|
store.register_events(old_state_2)
|
||||||
|
|
||||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
context = yield self._get_context(
|
||||||
|
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
|
old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_context(self, event, old_state_1, old_state_2):
|
def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
|
||||||
group_name_1 = "group_name_1"
|
old_state_2):
|
||||||
group_name_2 = "group_name_2"
|
sg1 = self.store.store_state_group(
|
||||||
|
prev_event_id_1, event.room_id, None, None,
|
||||||
|
{(e.type, e.state_key): e.event_id for e in old_state_1},
|
||||||
|
)
|
||||||
|
self.store.register_event_id_state_group(prev_event_id_1, sg1)
|
||||||
|
|
||||||
self.store.get_state_groups_ids.return_value = {
|
sg2 = self.store.store_state_group(
|
||||||
group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
|
prev_event_id_2, event.room_id, None, None,
|
||||||
group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
|
{(e.type, e.state_key): e.event_id for e in old_state_2},
|
||||||
}
|
)
|
||||||
|
self.store.register_event_id_state_group(prev_event_id_2, sg2)
|
||||||
|
|
||||||
return self.state.compute_event_context(event)
|
return self.state.compute_event_context(event)
|
||||||
|
|
Loading…
Reference in a new issue