mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-21 17:15:38 +03:00
Store state
This commit is contained in:
parent
b3d8e2d2bd
commit
3838b18d3b
9 changed files with 601 additions and 103 deletions
|
@ -98,6 +98,7 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
|||
from synapse.storage.databases.main.search import SearchStore
|
||||
from synapse.storage.databases.main.session import SessionStore
|
||||
from synapse.storage.databases.main.signatures import SignatureWorkerStore
|
||||
from synapse.storage.databases.main.sliding_sync import SlidingSyncStore
|
||||
from synapse.storage.databases.main.state import StateGroupWorkerStore
|
||||
from synapse.storage.databases.main.stats import StatsStore
|
||||
from synapse.storage.databases.main.stream import StreamWorkerStore
|
||||
|
@ -159,6 +160,7 @@ class GenericWorkerStore(
|
|||
SessionStore,
|
||||
TaskSchedulerWorkerStore,
|
||||
ExperimentalFeaturesStore,
|
||||
SlidingSyncStore,
|
||||
):
|
||||
# Properties that multiple storage classes define. Tell mypy what the
|
||||
# expected type is.
|
||||
|
|
|
@ -208,7 +208,7 @@ class SlidingSyncHandler:
|
|||
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
self.connection_store = SlidingSyncConnectionStore()
|
||||
self.connection_store = SlidingSyncConnectionStore(self.store)
|
||||
self.extensions = SlidingSyncExtensionHandler(hs)
|
||||
|
||||
async def wait_for_sync_for_user(
|
||||
|
|
|
@ -13,12 +13,12 @@
|
|||
#
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.errors import SlidingSyncUnknownPosition
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.types import SlidingSyncStreamToken
|
||||
from synapse.types.handlers.sliding_sync import (
|
||||
MutablePerConnectionState,
|
||||
|
@ -61,20 +61,7 @@ class SlidingSyncConnectionStore:
|
|||
to mapping of room ID to `HaveSentRoom`.
|
||||
"""
|
||||
|
||||
# `(user_id, conn_id)` -> `connection_position` -> `PerConnectionState`
|
||||
_connections: Dict[Tuple[str, str], Dict[int, PerConnectionState]] = attr.Factory(
|
||||
dict
|
||||
)
|
||||
|
||||
async def is_valid_token(
|
||||
self, sync_config: SlidingSyncConfig, connection_token: int
|
||||
) -> bool:
|
||||
"""Return whether the connection token is valid/recognized"""
|
||||
if connection_token == 0:
|
||||
return True
|
||||
|
||||
conn_key = self._get_connection_key(sync_config)
|
||||
return connection_token in self._connections.get(conn_key, {})
|
||||
store: "DataStore"
|
||||
|
||||
async def get_per_connection_state(
|
||||
self,
|
||||
|
@ -86,23 +73,20 @@ class SlidingSyncConnectionStore:
|
|||
Raises:
|
||||
SlidingSyncUnknownPosition if the connection_token is unknown
|
||||
"""
|
||||
if from_token is None:
|
||||
if from_token is None or from_token.connection_position == 0:
|
||||
return PerConnectionState()
|
||||
|
||||
connection_position = from_token.connection_position
|
||||
if connection_position == 0:
|
||||
# Initial sync (request without a `from_token`) starts at `0` so
|
||||
# there is no existing per-connection state
|
||||
return PerConnectionState()
|
||||
conn_id = sync_config.conn_id or ""
|
||||
|
||||
conn_key = self._get_connection_key(sync_config)
|
||||
sync_statuses = self._connections.get(conn_key, {})
|
||||
connection_state = sync_statuses.get(connection_position)
|
||||
device_id = sync_config.requester.device_id
|
||||
assert device_id is not None
|
||||
|
||||
if connection_state is None:
|
||||
raise SlidingSyncUnknownPosition()
|
||||
|
||||
return connection_state
|
||||
return await self.store.get_per_connection_state(
|
||||
sync_config.user.to_string(),
|
||||
device_id,
|
||||
conn_id,
|
||||
from_token.connection_position,
|
||||
)
|
||||
|
||||
@trace
|
||||
async def record_new_state(
|
||||
|
@ -116,26 +100,27 @@ class SlidingSyncConnectionStore:
|
|||
If there are no changes to the state this may return the same token as
|
||||
the existing per-connection state.
|
||||
"""
|
||||
prev_connection_token = 0
|
||||
if from_token is not None:
|
||||
prev_connection_token = from_token.connection_position
|
||||
|
||||
if not new_connection_state.has_updates():
|
||||
return prev_connection_token
|
||||
if from_token is not None:
|
||||
return from_token.connection_position
|
||||
else:
|
||||
return 0
|
||||
|
||||
conn_key = self._get_connection_key(sync_config)
|
||||
sync_statuses = self._connections.setdefault(conn_key, {})
|
||||
if from_token is not None and from_token.connection_position == 0:
|
||||
from_token = None
|
||||
|
||||
# Generate a new token, removing any existing entries in that token
|
||||
# (which can happen if requests get resent).
|
||||
new_store_token = prev_connection_token + 1
|
||||
sync_statuses.pop(new_store_token, None)
|
||||
conn_id = sync_config.conn_id or ""
|
||||
|
||||
# We copy the `MutablePerConnectionState` so that the inner `ChainMap`s
|
||||
# don't grow forever.
|
||||
sync_statuses[new_store_token] = new_connection_state.copy()
|
||||
device_id = sync_config.requester.device_id
|
||||
assert device_id is not None
|
||||
|
||||
return new_store_token
|
||||
return await self.store.persist_per_connection_state(
|
||||
sync_config.user.to_string(),
|
||||
device_id,
|
||||
conn_id,
|
||||
from_token.connection_position if from_token else None,
|
||||
new_connection_state,
|
||||
)
|
||||
|
||||
@trace
|
||||
async def mark_token_seen(
|
||||
|
@ -143,58 +128,4 @@ class SlidingSyncConnectionStore:
|
|||
sync_config: SlidingSyncConfig,
|
||||
from_token: Optional[SlidingSyncStreamToken],
|
||||
) -> None:
|
||||
"""We have received a request with the given token, so we can clear out
|
||||
any other tokens associated with the connection.
|
||||
|
||||
If there is no from token then we have started afresh, and so we delete
|
||||
all tokens associated with the device.
|
||||
"""
|
||||
# Clear out any tokens for the connection that doesn't match the one
|
||||
# from the request.
|
||||
|
||||
conn_key = self._get_connection_key(sync_config)
|
||||
sync_statuses = self._connections.pop(conn_key, {})
|
||||
if from_token is None:
|
||||
return
|
||||
|
||||
sync_statuses = {
|
||||
connection_token: room_statuses
|
||||
for connection_token, room_statuses in sync_statuses.items()
|
||||
if connection_token == from_token.connection_position
|
||||
}
|
||||
if sync_statuses:
|
||||
self._connections[conn_key] = sync_statuses
|
||||
|
||||
@staticmethod
|
||||
def _get_connection_key(sync_config: SlidingSyncConfig) -> Tuple[str, str]:
|
||||
"""Return a unique identifier for this connection.
|
||||
|
||||
The first part is simply the user ID.
|
||||
|
||||
The second part is generally a combination of device ID and conn_id.
|
||||
However, both these two are optional (e.g. puppet access tokens don't
|
||||
have device IDs), so this handles those edge cases.
|
||||
|
||||
We use this over the raw `conn_id` to avoid clashes between different
|
||||
clients that use the same `conn_id`. Imagine a user uses a web client
|
||||
that uses `conn_id: main_sync_loop` and an Android client that also has
|
||||
a `conn_id: main_sync_loop`.
|
||||
"""
|
||||
|
||||
user_id = sync_config.user.to_string()
|
||||
|
||||
# Only one sliding sync connection is allowed per given conn_id (empty
|
||||
# or not).
|
||||
conn_id = sync_config.conn_id or ""
|
||||
|
||||
if sync_config.requester.device_id:
|
||||
return (user_id, f"D/{sync_config.requester.device_id}/{conn_id}")
|
||||
|
||||
if sync_config.requester.access_token_id:
|
||||
# If we don't have a device, then the access token ID should be a
|
||||
# stable ID.
|
||||
return (user_id, f"A/{sync_config.requester.access_token_id}/{conn_id}")
|
||||
|
||||
# If we have neither then its likely an AS or some weird token. Either
|
||||
# way we can just fail here.
|
||||
raise Exception("Cannot use sliding sync with access token type")
|
||||
pass
|
||||
|
|
|
@ -33,6 +33,7 @@ from synapse.storage.database import (
|
|||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.databases.main.sliding_sync import SlidingSyncStore
|
||||
from synapse.storage.databases.main.stats import UserSortOrder
|
||||
from synapse.storage.engines import BaseDatabaseEngine
|
||||
from synapse.storage.types import Cursor
|
||||
|
@ -156,6 +157,7 @@ class DataStore(
|
|||
LockStore,
|
||||
SessionStore,
|
||||
TaskSchedulerWorkerStore,
|
||||
SlidingSyncStore,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
473
synapse/storage/databases/main/sliding_sync.py
Normal file
473
synapse/storage/databases/main/sliding_sync.py
Normal file
|
@ -0,0 +1,473 @@
|
|||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2023 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, cast
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.errors import SlidingSyncUnknownPosition
|
||||
from synapse.logging.opentracing import log_kv
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.types import MultiWriterStreamToken, RoomStreamToken
|
||||
from synapse.types.handlers.sliding_sync import (
|
||||
HaveSentRoom,
|
||||
HaveSentRoomFlag,
|
||||
MutablePerConnectionState,
|
||||
PerConnectionState,
|
||||
RoomStatusMap,
|
||||
RoomSyncConfig,
|
||||
)
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
|
||||
class SlidingSyncStore(SQLBaseStore):
|
||||
async def persist_per_connection_state(
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
conn_id: str,
|
||||
previous_connection_position: Optional[int],
|
||||
per_connection_state: "MutablePerConnectionState",
|
||||
) -> int:
|
||||
"""Persist updates to the per-connection state for a sliding sync
|
||||
connection.
|
||||
|
||||
Returns:
|
||||
The connection position of the newly persisted state.
|
||||
"""
|
||||
|
||||
store = cast("DataStore", self)
|
||||
return await self.db_pool.runInteraction(
|
||||
"persist_per_connection_state",
|
||||
self.persist_per_connection_state_txn,
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
conn_id=conn_id,
|
||||
previous_connection_position=previous_connection_position,
|
||||
per_connection_state=await PerConnectionStateDB.from_state(
|
||||
per_connection_state, store
|
||||
),
|
||||
)
|
||||
|
||||
def persist_per_connection_state_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
conn_id: str,
|
||||
previous_connection_position: Optional[int],
|
||||
per_connection_state: "PerConnectionStateDB",
|
||||
) -> int:
|
||||
# First we fetch the (or create) the connection key associated with the
|
||||
# previous connection position.
|
||||
if previous_connection_position is not None:
|
||||
# The `previous_connection_position` is a user-supplied value, so we
|
||||
# need to make sure that the one they supplied is actually theirs.
|
||||
sql = """
|
||||
SELECT connection_key
|
||||
FROM sliding_sync_connection_positions
|
||||
INNER JOIN sliding_sync_connections USING (connection_key)
|
||||
WHERE
|
||||
connection_position = ?
|
||||
AND user_id = ? AND device_id = ? AND conn_id = ?
|
||||
"""
|
||||
txn.execute(
|
||||
sql, (previous_connection_position, user_id, device_id, conn_id)
|
||||
)
|
||||
row = txn.fetchone()
|
||||
if row is None:
|
||||
raise SlidingSyncUnknownPosition()
|
||||
|
||||
(connection_key,) = row
|
||||
else:
|
||||
# We're restarting the connection, so we clear all existing
|
||||
# connections. We do this here to ensure that if we get lots of
|
||||
# one-shot requests we don't stack up lots of entries.
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
table="sliding_sync_connections",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"conn_id": conn_id,
|
||||
},
|
||||
)
|
||||
|
||||
(connection_key,) = self.db_pool.simple_insert_returning_txn(
|
||||
txn,
|
||||
table="sliding_sync_connections",
|
||||
values={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"conn_id": conn_id,
|
||||
"created_ts": self._clock.time_msec(),
|
||||
},
|
||||
returning=("connection_key",),
|
||||
)
|
||||
|
||||
# Define a new connection position for the updates
|
||||
(connection_position,) = self.db_pool.simple_insert_returning_txn(
|
||||
txn,
|
||||
table="sliding_sync_connection_positions",
|
||||
values={
|
||||
"connection_key": connection_key,
|
||||
"created_ts": self._clock.time_msec(),
|
||||
},
|
||||
returning=("connection_position",),
|
||||
)
|
||||
|
||||
# We need to deduplicate the `required_state` JSON. We do this by
|
||||
# fetching all JSON associated with the connection and comparing that
|
||||
# with the updates to `required_state`
|
||||
|
||||
# Dict from required state json -> required state ID
|
||||
required_state_to_id: Dict[str, int] = {}
|
||||
if previous_connection_position is not None:
|
||||
rows = self.db_pool.simple_select_list_txn(
|
||||
txn,
|
||||
table="sliding_sync_connection_required_state",
|
||||
keyvalues={"connection_key": connection_key},
|
||||
retcols=("required_state_id", "required_state"),
|
||||
)
|
||||
for required_state_id, required_state in rows:
|
||||
required_state_to_id[required_state] = required_state_id
|
||||
|
||||
room_to_state_ids: Dict[str, int] = {}
|
||||
unique_required_state: Dict[str, List[str]] = {}
|
||||
for room_id, room_state in per_connection_state.room_configs.items():
|
||||
serialized_state = json_encoder.encode(
|
||||
# We store the required state as a sorted list of event type /
|
||||
# state key tuples.
|
||||
sorted(
|
||||
(event_type, state_key)
|
||||
for event_type, state_keys in room_state.required_state_map.items()
|
||||
for state_key in state_keys
|
||||
)
|
||||
)
|
||||
|
||||
existing_state_id = required_state_to_id.get(serialized_state)
|
||||
if existing_state_id is not None:
|
||||
room_to_state_ids[room_id] = existing_state_id
|
||||
else:
|
||||
unique_required_state.setdefault(serialized_state, []).append(room_id)
|
||||
|
||||
# Insert any new `required_state` json we haven't previously seen.
|
||||
for serialized_required_state, room_ids in unique_required_state.items():
|
||||
(required_state_id,) = self.db_pool.simple_insert_returning_txn(
|
||||
txn,
|
||||
table="sliding_sync_connection_required_state",
|
||||
values={
|
||||
"connection_key": connection_key,
|
||||
"required_state": serialized_required_state,
|
||||
},
|
||||
returning=("required_state_id",),
|
||||
)
|
||||
for room_id in room_ids:
|
||||
room_to_state_ids[room_id] = required_state_id
|
||||
|
||||
# Copy over state from the previous connection position (we'll overwrite
|
||||
# these rows with any changes).
|
||||
if previous_connection_position is not None:
|
||||
sql = """
|
||||
INSERT INTO sliding_sync_connection_streams
|
||||
(connection_position, stream, room_id, room_status, last_position)
|
||||
SELECT ?, stream, room_id, room_status, last_position
|
||||
FROM sliding_sync_connection_streams
|
||||
WHERE connection_position = ?
|
||||
"""
|
||||
txn.execute(sql, (connection_position, previous_connection_position))
|
||||
|
||||
sql = """
|
||||
INSERT INTO sliding_sync_connection_room_configs
|
||||
(connection_position, room_id, timeline_limit, required_state_id)
|
||||
SELECT ?, room_id, timeline_limit, required_state_id
|
||||
FROM sliding_sync_connection_room_configs
|
||||
WHERE connection_position = ?
|
||||
"""
|
||||
txn.execute(sql, (connection_position, previous_connection_position))
|
||||
|
||||
# We now upsert the changes to the various streams.
|
||||
key_values = []
|
||||
value_values = []
|
||||
for room_id, have_sent_room in per_connection_state.rooms._statuses.items():
|
||||
key_values.append((connection_position, "rooms", room_id))
|
||||
value_values.append(
|
||||
(have_sent_room.status.value, have_sent_room.last_token)
|
||||
)
|
||||
|
||||
for room_id, have_sent_room in per_connection_state.receipts._statuses.items():
|
||||
key_values.append((connection_position, "receipts", room_id))
|
||||
value_values.append(
|
||||
(have_sent_room.status.value, have_sent_room.last_token)
|
||||
)
|
||||
|
||||
self.db_pool.simple_upsert_many_txn(
|
||||
txn,
|
||||
table="sliding_sync_connection_streams",
|
||||
key_names=(
|
||||
"connection_position",
|
||||
"stream",
|
||||
"room_id",
|
||||
),
|
||||
key_values=key_values,
|
||||
value_names=(
|
||||
"room_status",
|
||||
"last_position",
|
||||
),
|
||||
value_values=value_values,
|
||||
)
|
||||
|
||||
# ... and upsert changes to the room configs.
|
||||
keys = []
|
||||
values = []
|
||||
for room_id, room_config in per_connection_state.room_configs.items():
|
||||
keys.append((connection_position, room_id))
|
||||
values.append((room_config.timeline_limit, room_to_state_ids[room_id]))
|
||||
|
||||
self.db_pool.simple_upsert_many_txn(
|
||||
txn,
|
||||
table="sliding_sync_connection_room_configs",
|
||||
key_names=(
|
||||
"connection_position",
|
||||
"room_id",
|
||||
),
|
||||
key_values=keys,
|
||||
value_names=(
|
||||
"timeline_limit",
|
||||
"required_state_id",
|
||||
),
|
||||
value_values=values,
|
||||
)
|
||||
|
||||
return connection_position
|
||||
|
||||
@cached(iterable=True, max_entries=100000)
|
||||
async def get_per_connection_state(
|
||||
self, user_id: str, device_id: str, conn_id: str, connection_position: int
|
||||
) -> "PerConnectionState":
|
||||
"""Get the per-connection state for the given connection position."""
|
||||
|
||||
per_connection_state_db = await self.db_pool.runInteraction(
|
||||
"get_per_connection_state",
|
||||
self._get_per_connection_state_txn,
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
conn_id=conn_id,
|
||||
connection_position=connection_position,
|
||||
)
|
||||
store = cast("DataStore", self)
|
||||
return await per_connection_state_db.to_state(store)
|
||||
|
||||
def _get_per_connection_state_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
conn_id: str,
|
||||
connection_position: int,
|
||||
) -> "PerConnectionStateDB":
|
||||
# The `previous_connection_position` is a user-supplied value, so we
|
||||
# need to make sure that the one they supplied is actually theirs.
|
||||
sql = """
|
||||
SELECT connection_key
|
||||
FROM sliding_sync_connection_positions
|
||||
INNER JOIN sliding_sync_connections USING (connection_key)
|
||||
WHERE
|
||||
connection_position = ?
|
||||
AND user_id = ? AND device_id = ? AND conn_id = ?
|
||||
"""
|
||||
txn.execute(sql, (connection_position, user_id, device_id, conn_id))
|
||||
row = txn.fetchone()
|
||||
if row is None:
|
||||
raise SlidingSyncUnknownPosition()
|
||||
|
||||
(connection_key,) = row
|
||||
|
||||
# Now that we have seen the client has received and used the connection
|
||||
# position, we can delete all the other connection positions.
|
||||
sql = """
|
||||
DELETE FROM sliding_sync_connection_positions
|
||||
WHERE connection_key = ? AND connection_position != ?
|
||||
"""
|
||||
txn.execute(sql, (connection_key, connection_position))
|
||||
|
||||
# Fetch and create a mapping from required state ID to the actual
|
||||
# required state for the connection.
|
||||
rows = self.db_pool.simple_select_list_txn(
|
||||
txn,
|
||||
table="sliding_sync_connection_required_state",
|
||||
keyvalues={"connection_key": connection_key},
|
||||
retcols=(
|
||||
"required_state_id",
|
||||
"required_state",
|
||||
),
|
||||
)
|
||||
|
||||
required_state_map: Dict[int, Dict[str, Set[str]]] = {}
|
||||
for row in rows:
|
||||
state = required_state_map[row[0]] = {}
|
||||
for event_type, state_keys in db_to_json(row[1]):
|
||||
state[event_type] = set(state_keys)
|
||||
|
||||
# Get all the room configs, looking up the required state from the map
|
||||
# above.
|
||||
room_config_rows = self.db_pool.simple_select_list_txn(
|
||||
txn,
|
||||
table="sliding_sync_connection_room_configs",
|
||||
keyvalues={"connection_position": connection_position},
|
||||
retcols=(
|
||||
"room_id",
|
||||
"timeline_limit",
|
||||
"required_state_id",
|
||||
),
|
||||
)
|
||||
|
||||
room_configs: Dict[str, RoomSyncConfig] = {}
|
||||
for (
|
||||
room_id,
|
||||
timeline_limit,
|
||||
required_state_id,
|
||||
) in room_config_rows:
|
||||
room_configs[room_id] = RoomSyncConfig(
|
||||
timeline_limit=timeline_limit,
|
||||
required_state_map=required_state_map[required_state_id],
|
||||
)
|
||||
|
||||
# Now look up the per-room stream data.
|
||||
rooms: Dict[str, HaveSentRoom[str]] = {}
|
||||
receipts: Dict[str, HaveSentRoom[str]] = {}
|
||||
|
||||
receipt_rows = self.db_pool.simple_select_list_txn(
|
||||
txn,
|
||||
table="sliding_sync_connection_streams",
|
||||
keyvalues={"connection_position": connection_position},
|
||||
retcols=(
|
||||
"stream",
|
||||
"room_id",
|
||||
"room_status",
|
||||
"last_position",
|
||||
),
|
||||
)
|
||||
for stream, room_id, room_status, last_position in receipt_rows:
|
||||
have_sent_room: HaveSentRoom[str] = HaveSentRoom(
|
||||
status=HaveSentRoomFlag(room_status), last_token=last_position
|
||||
)
|
||||
if stream == "rooms":
|
||||
rooms[room_id] = have_sent_room
|
||||
elif stream == "receipts":
|
||||
receipts[room_id] = have_sent_room
|
||||
|
||||
return PerConnectionStateDB(
|
||||
rooms=RoomStatusMap(rooms),
|
||||
receipts=RoomStatusMap(receipts),
|
||||
room_configs=room_configs,
|
||||
)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True, frozen=True)
|
||||
class PerConnectionStateDB:
|
||||
"""An equivalent to `PerConnectionState` that holds data in a format stored
|
||||
in the DB.
|
||||
|
||||
The principle difference is that the tokens for the different streams are
|
||||
serialized to strings.
|
||||
|
||||
When persisting this *only* contains updates to the state.
|
||||
"""
|
||||
|
||||
rooms: "RoomStatusMap[str]"
|
||||
receipts: "RoomStatusMap[str]"
|
||||
|
||||
room_configs: Mapping[str, "RoomSyncConfig"]
|
||||
|
||||
@staticmethod
|
||||
async def from_state(
|
||||
per_connection_state: "MutablePerConnectionState", store: "DataStore"
|
||||
) -> "PerConnectionStateDB":
|
||||
"""Convert from a standard `PerConnectionState`"""
|
||||
rooms = {
|
||||
room_id: HaveSentRoom(
|
||||
status=status.status,
|
||||
last_token=(
|
||||
await status.last_token.to_string(store)
|
||||
if status.last_token is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for room_id, status in per_connection_state.rooms.get_updates().items()
|
||||
}
|
||||
|
||||
receipts = {
|
||||
room_id: HaveSentRoom(
|
||||
status=status.status,
|
||||
last_token=(
|
||||
await status.last_token.to_string(store)
|
||||
if status.last_token is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for room_id, status in per_connection_state.receipts.get_updates().items()
|
||||
}
|
||||
|
||||
log_kv(
|
||||
{
|
||||
"rooms": rooms,
|
||||
"receipts": receipts,
|
||||
"room_configs": per_connection_state.room_configs.maps[0],
|
||||
}
|
||||
)
|
||||
|
||||
return PerConnectionStateDB(
|
||||
rooms=RoomStatusMap(rooms),
|
||||
receipts=RoomStatusMap(receipts),
|
||||
room_configs=per_connection_state.room_configs.maps[0],
|
||||
)
|
||||
|
||||
async def to_state(self, store: "DataStore") -> "PerConnectionState":
|
||||
"""Convert into a standard `PerConnectionState`"""
|
||||
rooms = {
|
||||
room_id: HaveSentRoom(
|
||||
status=status.status,
|
||||
last_token=(
|
||||
await RoomStreamToken.parse(store, status.last_token)
|
||||
if status.last_token is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for room_id, status in self.rooms._statuses.items()
|
||||
}
|
||||
|
||||
receipts = {
|
||||
room_id: HaveSentRoom(
|
||||
status=status.status,
|
||||
last_token=(
|
||||
await MultiWriterStreamToken.parse(store, status.last_token)
|
||||
if status.last_token is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for room_id, status in self.receipts._statuses.items()
|
||||
}
|
||||
|
||||
return PerConnectionState(
|
||||
rooms=RoomStatusMap(rooms),
|
||||
receipts=RoomStatusMap(receipts),
|
||||
room_configs=self.room_configs,
|
||||
)
|
|
@ -19,7 +19,7 @@
|
|||
#
|
||||
#
|
||||
|
||||
SCHEMA_VERSION = 86 # remember to update the list below when updating
|
||||
SCHEMA_VERSION = 87 # remember to update the list below when updating
|
||||
"""Represents the expectations made by the codebase about the database schema
|
||||
|
||||
This should be incremented whenever the codebase changes its requirements on the
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2024 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
|
||||
-- Table to track active sliding sync connections.
|
||||
--
|
||||
-- A new connection will be created for every sliding sync request without a
|
||||
-- `since` token for a given `conn_id` for a device.#
|
||||
--
|
||||
-- Once a new connection is created and used we delete all other connections for
|
||||
-- the `conn_id`.
|
||||
CREATE TABLE sliding_sync_connections(
|
||||
connection_key $%AUTO_INCREMENT_PRIMARY_KEY%$,
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
conn_id TEXT NOT NULL,
|
||||
created_ts BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX sliding_sync_connections_idx ON sliding_sync_connections(user_id, device_id, conn_id);
|
||||
|
||||
-- We track per-connection state by associating changes to the state with
|
||||
-- connection positions. This ensures that we correctly track state even if we
|
||||
-- see retries of requests.
|
||||
--
|
||||
-- If the client starts a "new" connection (by not specifying a since token),
|
||||
-- we'll clear out the other connections (to ensure that we don't end up with
|
||||
-- lots of connection keys).
|
||||
CREATE TABLE sliding_sync_connection_positions(
|
||||
connection_position $%AUTO_INCREMENT_PRIMARY_KEY%$,
|
||||
connection_key BIGINT NOT NULL REFERENCES sliding_sync_connections(connection_key) ON DELETE CASCADE,
|
||||
created_ts BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX sliding_sync_connection_positions_key ON sliding_sync_connection_positions(connection_key);
|
||||
|
||||
|
||||
-- To save space we deduplicate the `required_state` json by assigning IDs to
|
||||
-- different values.
|
||||
CREATE TABLE sliding_sync_connection_required_state(
|
||||
required_state_id $%AUTO_INCREMENT_PRIMARY_KEY%$,
|
||||
connection_key BIGINT NOT NULL REFERENCES sliding_sync_connections(connection_key) ON DELETE CASCADE,
|
||||
required_state TEXT NOT NULL -- We store this as a json list of event type / state key tuples.
|
||||
);
|
||||
|
||||
CREATE INDEX sliding_sync_connection_required_state_conn_pos ON sliding_sync_connections(connection_key);
|
||||
|
||||
|
||||
-- Stores the room configs we have seen for rooms in a connection.
|
||||
CREATE TABLE sliding_sync_connection_room_configs(
|
||||
connection_position BIGINT NOT NULL REFERENCES sliding_sync_connection_positions(connection_position) ON DELETE CASCADE,
|
||||
room_id TEXT NOT NULL,
|
||||
timeline_limit BIGINT NOT NULL,
|
||||
required_state_id BIGINT NOT NULL REFERENCES sliding_sync_connection_required_state(required_state_id)
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX sliding_sync_connection_room_configs_idx ON sliding_sync_connection_room_configs(connection_position, room_id);
|
||||
|
||||
-- Stores what data we have sent for given streams down given connections.
|
||||
CREATE TABLE sliding_sync_connection_streams(
|
||||
connection_position BIGINT NOT NULL REFERENCES sliding_sync_connection_positions(connection_position) ON DELETE CASCADE,
|
||||
stream TEXT NOT NULL, -- e.g. "events" or "receipts"
|
||||
room_id TEXT NOT NULL,
|
||||
room_status TEXT NOT NULL, -- "live" or "previously", i.e. the `HaveSentRoomFlag` value
|
||||
last_position TEXT -- For "previously" the token for the stream we have sent up to.
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX sliding_sync_connection_streams_idx ON sliding_sync_connection_streams(connection_position, room_id, stream);
|
|
@ -741,6 +741,9 @@ class RoomStatusMap(Generic[T]):
|
|||
|
||||
return RoomStatusMap(statuses=dict(self._statuses))
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._statuses)
|
||||
|
||||
|
||||
class MutableRoomStatusMap(RoomStatusMap[T]):
|
||||
"""A mutable version of `RoomStatusMap`"""
|
||||
|
@ -842,6 +845,9 @@ class PerConnectionState:
|
|||
room_configs=dict(self.room_configs),
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.rooms) + len(self.receipts) + len(self.room_configs)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class MutablePerConnectionState(PerConnectionState):
|
||||
|
|
|
@ -191,8 +191,14 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
|
|||
}
|
||||
_, from_token = self.do_sync(sync_body, tok=user1_tok)
|
||||
|
||||
# Reset the in-memory cache
|
||||
self.hs.get_sliding_sync_handler().connection_store._connections.clear()
|
||||
# Reset the positions
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_delete(
|
||||
table="sliding_sync_connections",
|
||||
keyvalues={"user_id": user1_id},
|
||||
desc="clear_cache",
|
||||
)
|
||||
)
|
||||
|
||||
# Make the Sliding Sync request
|
||||
channel = self.make_request(
|
||||
|
|
Loading…
Reference in a new issue