Fixup slave stores

This commit is contained in:
Erik Johnston 2019-03-04 18:03:29 +00:00
parent aba5eeabd5
commit a84b8d56c2
6 changed files with 572 additions and 577 deletions

View file

@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage import DataStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage.deviceinbox import DeviceInboxWorkerStore
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore, __func__
from ._slaved_id_tracker import SlavedIdTracker
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
class SlavedDeviceInboxStore(BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker( self._device_inbox_id_gen = SlavedIdTracker(
@ -43,12 +42,6 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
expiry_ms=30 * 60 * 1000, expiry_ms=30 * 60 * 1000,
) )
get_to_device_stream_token = __func__(DataStore.get_to_device_stream_token)
get_new_messages_for_device = __func__(DataStore.get_new_messages_for_device)
get_new_device_msgs_for_remote = __func__(DataStore.get_new_device_msgs_for_remote)
delete_messages_for_device = __func__(DataStore.delete_messages_for_device)
delete_device_msgs_for_remote = __func__(DataStore.delete_device_msgs_for_remote)
def stream_positions(self): def stream_positions(self):
result = super(SlavedDeviceInboxStore, self).stream_positions() result = super(SlavedDeviceInboxStore, self).stream_positions()
result["to_device"] = self._device_inbox_id_gen.get_current_token() result["to_device"] = self._device_inbox_id_gen.get_current_token()

View file

@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage import DataStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.storage.end_to_end_keys import EndToEndKeyStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage.devices import DeviceWorkerStore
from synapse.storage.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore, __func__
from ._slaved_id_tracker import SlavedIdTracker
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
class SlavedDeviceStore(BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedDeviceStore, self).__init__(db_conn, hs) super(SlavedDeviceStore, self).__init__(db_conn, hs)
@ -38,17 +37,6 @@ class SlavedDeviceStore(BaseSlavedStore):
"DeviceListFederationStreamChangeCache", device_list_max, "DeviceListFederationStreamChangeCache", device_list_max,
) )
get_device_stream_token = __func__(DataStore.get_device_stream_token)
get_user_whose_devices_changed = __func__(DataStore.get_user_whose_devices_changed)
get_devices_by_remote = __func__(DataStore.get_devices_by_remote)
_get_devices_by_remote_txn = __func__(DataStore._get_devices_by_remote_txn)
_get_e2e_device_keys_txn = __func__(DataStore._get_e2e_device_keys_txn)
mark_as_sent_devices_by_remote = __func__(DataStore.mark_as_sent_devices_by_remote)
_mark_as_sent_devices_by_remote_txn = (
__func__(DataStore._mark_as_sent_devices_by_remote_txn)
)
count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]
def stream_positions(self): def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions() result = super(SlavedDeviceStore, self).stream_positions()
result["device_lists"] = self._device_list_id_gen.get_current_token() result["device_lists"] = self._device_list_id_gen.get_current_token()
@ -58,14 +46,23 @@ class SlavedDeviceStore(BaseSlavedStore):
if stream_name == "device_lists": if stream_name == "device_lists":
self._device_list_id_gen.advance(token) self._device_list_id_gen.advance(token)
for row in rows: for row in rows:
self._device_list_stream_cache.entity_has_changed( self._invalidate_caches_for_devices(
row.user_id, token token, row.user_id, row.destination,
)
if row.destination:
self._device_list_federation_stream_cache.entity_has_changed(
row.destination, token
) )
return super(SlavedDeviceStore, self).process_replication_rows( return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows stream_name, token, rows
) )
def _invalidate_caches_for_devices(self, token, user_id, destination):
self._device_list_stream_cache.entity_has_changed(
user_id, token
)
if destination:
self._device_list_federation_stream_cache.entity_has_changed(
destination, token
)
self._get_cached_devices_for_user.invalidate((user_id,))
self._get_cached_user_device.invalidate_many((user_id,))
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))

View file

@ -20,7 +20,7 @@ from ._slaved_id_tracker import SlavedIdTracker
from .events import SlavedEventStore from .events import SlavedEventStore
class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore): class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
self._push_rules_stream_id_gen = SlavedIdTracker( self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id", db_conn, "push_rules_stream", "stream_id",

View file

@ -19,14 +19,174 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from .background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DeviceInboxStore(BackgroundUpdateStore): class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
def get_new_messages_for_device(
self, user_id, device_id, last_stream_id, current_stream_id, limit=100
):
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
current_stream_id(int): The current position of the to device
message stream.
Returns:
Deferred ([dict], int): List of messages for the device and where
in the stream the messages got to.
"""
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id
)
if not has_changed:
return defer.succeed(([], current_stream_id))
def get_new_messages_for_device_txn(txn):
sql = (
"SELECT stream_id, message_json FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (
user_id, device_id, last_stream_id, current_stream_id, limit
))
messages = []
for row in txn:
stream_pos = row[0]
messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
return self.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn,
)
@defer.inlineCallbacks
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
up_to_stream_id(int): Where to delete messages up to.
Returns:
A deferred that resolves to the number of messages deleted.
"""
# If we have cached the last stream id we've deleted up to, we can
# check if there is likely to be anything that needs deleting
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), None
)
if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_deleted_stream_id
)
if not has_changed:
defer.returnValue(0)
def delete_messages_for_device_txn(txn):
sql = (
"DELETE FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND stream_id <= ?"
)
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
count = yield self.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
# Update the cache, ensuring that we only ever increase the value
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0
)
self._last_device_delete_cache[(user_id, device_id)] = max(
last_deleted_stream_id, up_to_stream_id
)
defer.returnValue(count)
def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit=100
):
"""
Args:
destination(str): The name of the remote server.
last_stream_id(int|long): The last position of the device message stream
that the server sent up to.
current_stream_id(int|long): The current position of the device
message stream.
Returns:
Deferred ([dict], int|long): List of messages for the device and where
in the stream the messages got to.
"""
has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
destination, last_stream_id
)
if not has_changed or last_stream_id == current_stream_id:
return defer.succeed(([], current_stream_id))
def get_new_messages_for_remote_destination_txn(txn):
sql = (
"SELECT stream_id, messages_json FROM device_federation_outbox"
" WHERE destination = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (
destination, last_stream_id, current_stream_id, limit
))
messages = []
for row in txn:
stream_pos = row[0]
messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
return self.runInteraction(
"get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn,
)
def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
"""Used to delete messages when the remote destination acknowledges
their receipt.
Args:
destination(str): The destination server_name
up_to_stream_id(int): Where to delete messages up to.
Returns:
A deferred that resolves when the messages have been deleted.
"""
def delete_messages_for_remote_destination_txn(txn):
sql = (
"DELETE FROM device_federation_outbox"
" WHERE destination = ?"
" AND stream_id <= ?"
)
txn.execute(sql, (destination, up_to_stream_id))
return self.runInteraction(
"delete_device_msgs_for_remote",
delete_messages_for_remote_destination_txn
)
class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
@ -220,93 +380,6 @@ class DeviceInboxStore(BackgroundUpdateStore):
txn.executemany(sql, rows) txn.executemany(sql, rows)
def get_new_messages_for_device(
self, user_id, device_id, last_stream_id, current_stream_id, limit=100
):
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
current_stream_id(int): The current position of the to device
message stream.
Returns:
Deferred ([dict], int): List of messages for the device and where
in the stream the messages got to.
"""
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id
)
if not has_changed:
return defer.succeed(([], current_stream_id))
def get_new_messages_for_device_txn(txn):
sql = (
"SELECT stream_id, message_json FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (
user_id, device_id, last_stream_id, current_stream_id, limit
))
messages = []
for row in txn:
stream_pos = row[0]
messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
return self.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn,
)
@defer.inlineCallbacks
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
up_to_stream_id(int): Where to delete messages up to.
Returns:
A deferred that resolves to the number of messages deleted.
"""
# If we have cached the last stream id we've deleted up to, we can
# check if there is likely to be anything that needs deleting
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), None
)
if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_deleted_stream_id
)
if not has_changed:
defer.returnValue(0)
def delete_messages_for_device_txn(txn):
sql = (
"DELETE FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND stream_id <= ?"
)
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
count = yield self.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
# Update the cache, ensuring that we only ever increase the value
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0
)
self._last_device_delete_cache[(user_id, device_id)] = max(
last_deleted_stream_id, up_to_stream_id
)
defer.returnValue(count)
def get_all_new_device_messages(self, last_pos, current_pos, limit): def get_all_new_device_messages(self, last_pos, current_pos, limit):
""" """
Args: Args:
@ -351,77 +424,6 @@ class DeviceInboxStore(BackgroundUpdateStore):
"get_all_new_device_messages", get_all_new_device_messages_txn "get_all_new_device_messages", get_all_new_device_messages_txn
) )
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit=100
):
"""
Args:
destination(str): The name of the remote server.
last_stream_id(int|long): The last position of the device message stream
that the server sent up to.
current_stream_id(int|long): The current position of the device
message stream.
Returns:
Deferred ([dict], int|long): List of messages for the device and where
in the stream the messages got to.
"""
has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
destination, last_stream_id
)
if not has_changed or last_stream_id == current_stream_id:
return defer.succeed(([], current_stream_id))
def get_new_messages_for_remote_destination_txn(txn):
sql = (
"SELECT stream_id, messages_json FROM device_federation_outbox"
" WHERE destination = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (
destination, last_stream_id, current_stream_id, limit
))
messages = []
for row in txn:
stream_pos = row[0]
messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
return self.runInteraction(
"get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn,
)
def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
"""Used to delete messages when the remote destination acknowledges
their receipt.
Args:
destination(str): The destination server_name
up_to_stream_id(int): Where to delete messages up to.
Returns:
A deferred that resolves when the messages have been deleted.
"""
def delete_messages_for_remote_destination_txn(txn):
sql = (
"DELETE FROM device_federation_outbox"
" WHERE destination = ?"
" AND stream_id <= ?"
)
txn.execute(sql, (destination, up_to_stream_id))
return self.runInteraction(
"delete_device_msgs_for_remote",
delete_messages_for_remote_destination_txn
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size): def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn): def reindex_txn(conn):

View file

@ -22,11 +22,10 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import Cache, SQLBaseStore, db_to_json
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from ._base import Cache, db_to_json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
@ -34,7 +33,343 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
) )
class DeviceStore(BackgroundUpdateStore): class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id, device_id):
"""Retrieve a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to retrieve
Returns:
defer.Deferred for a dict containing the device information
Raises:
StoreError: if the device is not found
"""
return self._simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
)
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""Retrieve all of a user's registered devices.
Args:
user_id (str):
Returns:
defer.Deferred: resolves to a dict from device_id to a dict
containing "device_id", "user_id" and "display_name" for each
device.
"""
devices = yield self._simple_select_list(
table="devices",
keyvalues={"user_id": user_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user"
)
defer.returnValue({d["device_id"]: d for d in devices})
def get_devices_by_remote(self, destination, from_stream_id):
"""Get stream of updates to send to remote servers
Returns:
(int, list[dict]): current stream id and list of updates
"""
now_stream_id = self._device_list_id_gen.get_current_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
)
if not has_changed:
return (now_stream_id, [])
return self.runInteraction(
"get_devices_by_remote", self._get_devices_by_remote_txn,
destination, from_stream_id, now_stream_id,
)
def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
now_stream_id):
sql = """
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
LIMIT 20
"""
txn.execute(
sql, (destination, from_stream_id, now_stream_id, False)
)
# maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in txn}
if not query_map:
return (now_stream_id, [])
if len(query_map) >= 20:
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
)
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
results = []
for user_id, user_devices in iteritems(devices):
# The prev_id for the first row is always the last row before
# `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
prev_id = rows[0][0]
for device_id, device in iteritems(user_devices):
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
}
prev_id = stream_id
if device is not None:
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
else:
result["deleted"] = True
results.append(result)
return (now_stream_id, results)
def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
"""
return self.runInteraction(
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
destination, stream_id,
)
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
# We update the device_lists_outbound_last_success with the successfully
# poked users. We do the join to see which users need to be inserted and
# which updated.
sql = """
SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
FROM device_lists_outbound_pokes as o
LEFT JOIN device_lists_outbound_last_success as s
USING (destination, user_id)
WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
"""
txn.execute(sql, (destination, stream_id,))
rows = txn.fetchall()
sql = """
UPDATE device_lists_outbound_last_success
SET stream_id = ?
WHERE destination = ? AND user_id = ?
"""
txn.executemany(
sql, ((row[1], destination, row[0],) for row in rows if row[2])
)
sql = """
INSERT INTO device_lists_outbound_last_success
(destination, user_id, stream_id) VALUES (?, ?, ?)
"""
txn.executemany(
sql, ((destination, row[0], row[1],) for row in rows if not row[2])
)
# Delete all sent outbound pokes
sql = """
DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ?
"""
txn.execute(sql, (destination, stream_id,))
def get_device_stream_token(self):
return self._device_list_id_gen.get_current_token()
@defer.inlineCallbacks
def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache.
Args:
query_list(list): List of (user_id, device_ids), if device_ids is
falsey then return all device ids for that user.
Returns:
(user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info
"""
user_ids = set(user_id for user_id, _ in query_list)
user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
user_ids_in_cache = set(
user_id for user_id, stream_id in user_map.items() if stream_id
)
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue
if device_id:
device = yield self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device
else:
results[user_id] = yield self._get_cached_devices_for_user(user_id)
defer.returnValue((user_ids_not_in_cache, results))
@cachedInlineCallbacks(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id):
content = yield self._simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="content",
desc="_get_cached_user_device",
)
defer.returnValue(db_to_json(content))
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
devices = yield self._simple_select_list(
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
retcols=("device_id", "content"),
desc="_get_cached_devices_for_user",
)
defer.returnValue({
device["device_id"]: db_to_json(device["content"])
for device in devices
})
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
Returns:
(stream_id, devices)
"""
return self.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn, user_id,
)
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
now_stream_id = self._device_list_id_gen.get_current_token()
devices = self._get_e2e_device_keys_txn(
txn, [(user_id, None)], include_all_devices=True
)
if devices:
user_devices = devices[user_id]
results = []
for device_id, device in iteritems(user_devices):
result = {
"device_id": device_id,
}
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return now_stream_id, results
return now_stream_id, []
@defer.inlineCallbacks
def get_user_whose_devices_changed(self, from_key):
"""Get set of users whose devices have changed since `from_key`.
"""
from_key = int(from_key)
changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
if changed is not None:
defer.returnValue(set(changed))
sql = """
SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
"""
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row[0] for row in rows))
def get_all_device_list_changes_for_remotes(self, from_key, to_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
"""
# We do a group by here as there can be a large number of duplicate
# entries, since we throw away device IDs.
sql = """
SELECT MAX(stream_id) AS stream_id, user_id, destination
FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id, destination
"""
return self._execute(
"get_all_device_list_changes_for_remotes", None,
sql, from_key, to_key
)
@cached(max_entries=10000)
def get_device_list_last_stream_id_for_remote(self, user_id):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
return self._simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
desc="get_device_list_last_stream_id_for_remote",
allow_none=True,
)
@cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids", inlineCallbacks=True)
def get_device_list_last_stream_id_for_remotes(self, user_ids):
rows = yield self._simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
retcols=("user_id", "stream_id",),
desc="get_device_list_last_stream_id_for_remotes",
)
results = {user_id: None for user_id in user_ids}
results.update({
row["user_id"]: row["stream_id"] for row in rows
})
defer.returnValue(results)
class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(DeviceStore, self).__init__(db_conn, hs) super(DeviceStore, self).__init__(db_conn, hs)
@ -121,24 +456,6 @@ class DeviceStore(BackgroundUpdateStore):
initial_device_display_name, e) initial_device_display_name, e)
raise StoreError(500, "Problem storing device.") raise StoreError(500, "Problem storing device.")
def get_device(self, user_id, device_id):
"""Retrieve a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to retrieve
Returns:
defer.Deferred for a dict containing the device information
Raises:
StoreError: if the device is not found
"""
return self._simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_device(self, user_id, device_id): def delete_device(self, user_id, device_id):
"""Delete a device. """Delete a device.
@ -202,57 +519,6 @@ class DeviceStore(BackgroundUpdateStore):
desc="update_device", desc="update_device",
) )
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""Retrieve all of a user's registered devices.
Args:
user_id (str):
Returns:
defer.Deferred: resolves to a dict from device_id to a dict
containing "device_id", "user_id" and "display_name" for each
device.
"""
devices = yield self._simple_select_list(
table="devices",
keyvalues={"user_id": user_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user"
)
defer.returnValue({d["device_id"]: d for d in devices})
@cached(max_entries=10000)
def get_device_list_last_stream_id_for_remote(self, user_id):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
return self._simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
desc="get_device_list_remote_extremity",
allow_none=True,
)
@cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids", inlineCallbacks=True)
def get_device_list_last_stream_id_for_remotes(self, user_ids):
rows = yield self._simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
retcols=("user_id", "stream_id",),
desc="get_user_devices_from_cache",
)
results = {user_id: None for user_id in user_ids}
results.update({
row["user_id"]: row["stream_id"] for row in rows
})
defer.returnValue(results)
@defer.inlineCallbacks @defer.inlineCallbacks
def mark_remote_user_device_list_as_unsubscribed(self, user_id): def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user. """Mark that we no longer track device lists for remote user.
@ -405,268 +671,6 @@ class DeviceStore(BackgroundUpdateStore):
lock=False, lock=False,
) )
def get_devices_by_remote(self, destination, from_stream_id):
"""Get stream of updates to send to remote servers
Returns:
(int, list[dict]): current stream id and list of updates
"""
now_stream_id = self._device_list_id_gen.get_current_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
)
if not has_changed:
return (now_stream_id, [])
return self.runInteraction(
"get_devices_by_remote", self._get_devices_by_remote_txn,
destination, from_stream_id, now_stream_id,
)
def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
now_stream_id):
sql = """
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
LIMIT 20
"""
txn.execute(
sql, (destination, from_stream_id, now_stream_id, False)
)
# maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in txn}
if not query_map:
return (now_stream_id, [])
if len(query_map) >= 20:
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
)
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
results = []
for user_id, user_devices in iteritems(devices):
# The prev_id for the first row is always the last row before
# `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
prev_id = rows[0][0]
for device_id, device in iteritems(user_devices):
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
}
prev_id = stream_id
if device is not None:
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
else:
result["deleted"] = True
results.append(result)
return (now_stream_id, results)
@defer.inlineCallbacks
def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache.
Args:
query_list(list): List of (user_id, device_ids), if device_ids is
falsey then return all device ids for that user.
Returns:
(user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info
"""
user_ids = set(user_id for user_id, _ in query_list)
user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
user_ids_in_cache = set(
user_id for user_id, stream_id in user_map.items() if stream_id
)
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue
if device_id:
device = yield self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device
else:
results[user_id] = yield self._get_cached_devices_for_user(user_id)
defer.returnValue((user_ids_not_in_cache, results))
@cachedInlineCallbacks(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id):
content = yield self._simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="content",
desc="_get_cached_user_device",
)
defer.returnValue(db_to_json(content))
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
devices = yield self._simple_select_list(
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
retcols=("device_id", "content"),
desc="_get_cached_devices_for_user",
)
defer.returnValue({
device["device_id"]: db_to_json(device["content"])
for device in devices
})
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
Returns:
(stream_id, devices)
"""
return self.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn, user_id,
)
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
now_stream_id = self._device_list_id_gen.get_current_token()
devices = self._get_e2e_device_keys_txn(
txn, [(user_id, None)], include_all_devices=True
)
if devices:
user_devices = devices[user_id]
results = []
for device_id, device in iteritems(user_devices):
result = {
"device_id": device_id,
}
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return now_stream_id, results
return now_stream_id, []
def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
"""
return self.runInteraction(
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
destination, stream_id,
)
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
# We update the device_lists_outbound_last_success with the successfully
# poked users. We do the join to see which users need to be inserted and
# which updated.
sql = """
SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
FROM device_lists_outbound_pokes as o
LEFT JOIN device_lists_outbound_last_success as s
USING (destination, user_id)
WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
"""
txn.execute(sql, (destination, stream_id,))
rows = txn.fetchall()
sql = """
UPDATE device_lists_outbound_last_success
SET stream_id = ?
WHERE destination = ? AND user_id = ?
"""
txn.executemany(
sql, ((row[1], destination, row[0],) for row in rows if row[2])
)
sql = """
INSERT INTO device_lists_outbound_last_success
(destination, user_id, stream_id) VALUES (?, ?, ?)
"""
txn.executemany(
sql, ((destination, row[0], row[1],) for row in rows if not row[2])
)
# Delete all sent outbound pokes
sql = """
DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ?
"""
txn.execute(sql, (destination, stream_id,))
@defer.inlineCallbacks
def get_user_whose_devices_changed(self, from_key):
"""Get set of users whose devices have changed since `from_key`.
"""
from_key = int(from_key)
changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
if changed is not None:
defer.returnValue(set(changed))
sql = """
SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
"""
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row[0] for row in rows))
def get_all_device_list_changes_for_remotes(self, from_key, to_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
"""
# We do a group by here as there can be a large number of duplicate
# entries, since we throw away device IDs.
sql = """
SELECT MAX(stream_id) AS stream_id, user_id, destination
FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id, destination
"""
return self._execute(
"get_all_device_list_changes_for_remotes", None,
sql, from_key, to_key
)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_device_change_to_streams(self, user_id, device_ids, hosts): def add_device_change_to_streams(self, user_id, device_ids, hosts):
"""Persist that a user's devices have been updated, and which hosts """Persist that a user's devices have been updated, and which hosts
@ -732,9 +736,6 @@ class DeviceStore(BackgroundUpdateStore):
] ]
) )
def get_device_stream_token(self):
return self._device_list_id_gen.get_current_token()
def _prune_old_outbound_device_pokes(self): def _prune_old_outbound_device_pokes(self):
"""Delete old entries out of the device_lists_outbound_pokes to ensure """Delete old entries out of the device_lists_outbound_pokes to ensure
that we don't fill up due to dead servers. We keep one entry per that we don't fill up due to dead servers. We keep one entry per

View file

@ -23,49 +23,7 @@ from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore, db_to_json from ._base import SQLBaseStore, db_to_json
class EndToEndKeyStore(SQLBaseStore): class EndToEndKeyWorkerStore(SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
def _set_e2e_device_keys_txn(txn):
old_key_json = self._simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="key_json",
allow_none=True,
)
# In py3 we need old_key_json to match new_key_json type. The DB
# returns unicode while encode_canonical_json returns bytes.
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
if old_key_json == new_key_json:
return False
self._simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"ts_added_ms": time_now,
"key_json": new_key_json,
}
)
return True
return self.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_e2e_device_keys( def get_e2e_device_keys(
self, query_list, include_all_devices=False, self, query_list, include_all_devices=False,
@ -238,6 +196,50 @@ class EndToEndKeyStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys "count_e2e_one_time_keys", _count_e2e_one_time_keys
) )
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
def _set_e2e_device_keys_txn(txn):
old_key_json = self._simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="key_json",
allow_none=True,
)
# In py3 we need old_key_json to match new_key_json type. The DB
# returns unicode while encode_canonical_json returns bytes.
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
if old_key_json == new_key_json:
return False
self._simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"ts_added_ms": time_now,
"key_json": new_key_json,
}
)
return True
return self.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
def claim_e2e_one_time_keys(self, query_list): def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database""" """Take a list of one time keys out of the database"""
def _claim_e2e_one_time_keys(txn): def _claim_e2e_one_time_keys(txn):