Fix replication. And notify

This commit is contained in:
Erik Johnston 2017-07-20 17:13:18 +01:00
parent 139fe30f47
commit 2cc998fed8
5 changed files with 119 additions and 4 deletions

View file

@ -41,6 +41,7 @@ from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.groups import SlavedGroupServerStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
@ -75,6 +76,7 @@ class SynchrotronSlavedStore(
SlavedRegistrationStore, SlavedRegistrationStore,
SlavedFilteringStore, SlavedFilteringStore,
SlavedPresenceStore, SlavedPresenceStore,
SlavedGroupServerStore,
SlavedDeviceInboxStore, SlavedDeviceInboxStore,
SlavedDeviceStore, SlavedDeviceStore,
SlavedClientIpStore, SlavedClientIpStore,
@ -409,6 +411,10 @@ class SyncReplicationHandler(ReplicationClientHandler):
) )
elif stream_name == "presence": elif stream_name == "presence":
yield self.presence_handler.process_replication_rows(token, rows) yield self.presence_handler.process_replication_rows(token, rows)
elif stream_name == "receipts":
self.notifier.on_new_event(
"groups_key", token, users=[row.user_id for row in rows],
)
def start(config_options): def start(config_options):

View file

@ -211,13 +211,16 @@ class GroupsLocalHandler(object):
user_id=user_id, user_id=user_id,
) )
yield self.store.register_user_group_membership( token = yield self.store.register_user_group_membership(
group_id, user_id, group_id, user_id,
membership="join", membership="join",
is_admin=False, is_admin=False,
local_attestation=local_attestation, local_attestation=local_attestation,
remote_attestation=remote_attestation, remote_attestation=remote_attestation,
) )
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
defer.returnValue({}) defer.returnValue({})
@ -257,11 +260,14 @@ class GroupsLocalHandler(object):
if "avatar_url" in content["profile"]: if "avatar_url" in content["profile"]:
local_profile["avatar_url"] = content["profile"]["avatar_url"] local_profile["avatar_url"] = content["profile"]["avatar_url"]
yield self.store.register_user_group_membership( token = yield self.store.register_user_group_membership(
group_id, user_id, group_id, user_id,
membership="invite", membership="invite",
content={"profile": local_profile, "inviter": content["inviter"]}, content={"profile": local_profile, "inviter": content["inviter"]},
) )
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
defer.returnValue({"state": "invite"}) defer.returnValue({"state": "invite"})
@ -270,10 +276,13 @@ class GroupsLocalHandler(object):
"""Remove a user from a group """Remove a user from a group
""" """
if user_id == requester_user_id: if user_id == requester_user_id:
yield self.store.register_user_group_membership( token = yield self.store.register_user_group_membership(
group_id, user_id, group_id, user_id,
membership="leave", membership="leave",
) )
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
# TODO: Should probably remember that we tried to leave so that we can # TODO: Should probably remember that we tried to leave so that we can
# retry if the group server is currently down. # retry if the group server is currently down.
@ -296,10 +305,13 @@ class GroupsLocalHandler(object):
"""One of our users was removed/kicked from a group """One of our users was removed/kicked from a group
""" """
# TODO: Check if user in group # TODO: Check if user in group
yield self.store.register_user_group_membership( token = yield self.store.register_user_group_membership(
group_id, user_id, group_id, user_id,
membership="leave", membership="leave",
) )
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_joined_groups(self, user_id): def get_joined_groups(self, user_id):

View file

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedGroupServerStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedGroupServerStore, self).__init__(db_conn, hs)
self.hs = hs
self._group_updates_id_gen = SlavedIdTracker(
db_conn, "local_group_updates", "stream_id",
)
self._group_updates_stream_cache = StreamChangeCache(
"_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
)
get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__
get_group_stream_token = DataStore.get_group_stream_token.__func__
get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__
def stream_positions(self):
result = super(SlavedGroupServerStore, self).stream_positions()
result["groups"] = self._group_updates_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "groups":
self._group_updates_id_gen.advance(token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(
row.user_id, token
)
return super(SlavedGroupServerStore, self).process_replication_rows(
stream_name, token, rows
)

View file

@ -118,6 +118,12 @@ CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
"state_key", # str "state_key", # str
"event_id", # str, optional "event_id", # str, optional
)) ))
GroupsStreamRow = namedtuple("GroupsStreamRow", (
"group_id", # str
"user_id", # str
"type", # str
"content", # dict
))
class Stream(object): class Stream(object):
@ -464,6 +470,19 @@ class CurrentStateDeltaStream(Stream):
super(CurrentStateDeltaStream, self).__init__(hs) super(CurrentStateDeltaStream, self).__init__(hs)
class GroupServerStream(Stream):
NAME = "groups"
ROW_TYPE = GroupsStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_group_stream_token
self.update_function = store.get_all_groups_changes
super(GroupServerStream, self).__init__(hs)
STREAMS_MAP = { STREAMS_MAP = {
stream.NAME: stream stream.NAME: stream
for stream in ( for stream in (
@ -482,5 +501,6 @@ STREAMS_MAP = {
TagAccountDataStream, TagAccountDataStream,
AccountDataStream, AccountDataStream,
CurrentStateDeltaStream, CurrentStateDeltaStream,
GroupServerStream,
) )
} }

View file

@ -853,6 +853,8 @@ class GroupServerStore(SQLBaseStore):
}, },
) )
return next_id
with self._group_updates_id_gen.get_next() as next_id: with self._group_updates_id_gen.get_next() as next_id:
yield self.runInteraction( yield self.runInteraction(
"register_user_group_membership", "register_user_group_membership",
@ -993,5 +995,26 @@ class GroupServerStore(SQLBaseStore):
"get_groups_changes_for_user", _get_groups_changes_for_user_txn, "get_groups_changes_for_user", _get_groups_changes_for_user_txn,
) )
def get_all_groups_changes(self, from_token, to_token, limit):
from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_any_entity_changed(
from_token,
)
if not has_changed:
return []
def _get_all_groups_changes_txn(txn):
sql = """
SELECT stream_id, group_id, user_id, type, content
FROM local_group_updates
WHERE ? < stream_id AND stream_id <= ?
LIMIT ?
"""
txn.execute(sql, (from_token, to_token, limit,))
return txn.fetchall()
return self.runInteraction(
"get_all_groups_changes", _get_all_groups_changes_txn,
)
def get_group_stream_token(self): def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token() return self._group_updates_id_gen.get_current_token()