Moar stuff

Erik Johnston 2017-04-24 09:46:24 +01:00
parent 3033261891
commit d4cb3edba8
8 changed files with 258 additions and 36 deletions

synapse/federation/transaction_queue.py

@@ -24,7 +24,6 @@ from synapse.util.async import run_on_reactor
 from synapse.util.logcontext import preserve_context_over_fn, preserve_fn
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 from synapse.util.metrics import measure_func
-from synapse.types import get_domain_from_id
 from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
 import synapse.metrics
@@ -183,15 +182,12 @@ class TransactionQueue(object):
             # Otherwise if the last member on a server in a room is
             # banned then it won't receive the event because it won't
             # be in the room after the ban.
-            users_in_room = yield self.state.get_current_user_in_room(
+            destinations = yield self.state.get_current_hosts_in_room(
                 event.room_id, latest_event_ids=[
                     prev_id for prev_id, _ in event.prev_events
                 ],
             )
-            destinations = set(
-                get_domain_from_id(user_id) for user_id in users_in_room
-            )
 
             if send_on_behalf_of is not None:
                 # If we are sending the event on behalf of another server
                 # then it already has the event and there is no reason to

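The removed get_domain_from_id() step is what previously turned the set of joined users into the set of destination servers; get_current_hosts_in_room now returns the hosts directly from the state handler. As a rough, standalone sketch of the step that went away (the helper name below is illustrative, not the synapse implementation): a Matrix user ID has the form @localpart:servername, so the destinations were simply the domain part of every joined user.

# Standalone sketch of the inline domain extraction the removed code relied on;
# illustrative only, not synapse's get_domain_from_id.
def domain_from_user_id(user_id):
    # A Matrix user ID looks like "@localpart:server.name"; everything after
    # the first ":" is the server the federation sender must reach.
    return user_id.split(":", 1)[1]

users_in_room = ["@alice:example.org", "@bob:example.org", "@carol:matrix.org"]
destinations = set(domain_from_user_id(u) for u in users_in_room)
# destinations == {"example.org", "matrix.org"}
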
synapse/push/bulk_push_rule_evaluator.py

@@ -19,17 +19,21 @@ from twisted.internet import defer
 from .push_rule_evaluator import PushRuleEvaluatorForEvent
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
 
 logger = logging.getLogger(__name__)
 
+rules_by_room = {}
+
 
 @defer.inlineCallbacks
 def evaluator_for_event(event, hs, store, context):
-    rules_by_user = yield store.bulk_get_push_rules_for_room(
-        event, context
-    )
+    room_id = event.room_id
+    rules_for_room = rules_by_room.setdefault(room_id, RulesForRoom(hs, room_id))
+
+    rules_by_user = yield rules_for_room.get_rules(context)
 
     # if this event is an invite event, we may need to run rules for the user
     # who's been invited, otherwise they won't get told they've been invited
@@ -66,11 +70,14 @@ class BulkPushRuleEvaluator:
     def action_for_event_by_user(self, event, context):
         actions_by_user = {}
 
-        room_members = yield self.store.get_joined_users_from_context(
-            event, context
-        )
+        # room_members = yield self.store.get_joined_users_from_context(
+        #     event, context
+        # )
+        room_members = {}
 
-        evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
+        num_room_members = yield self.store.get_number_of_users_in_rooms(event.room_id)
+
+        evaluator = PushRuleEvaluatorForEvent(event, num_room_members)
 
         condition_cache = {}
@@ -127,3 +134,128 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache):
         return False
 
     return True
+
+
+class RulesForRoom(object):
+    def __init__(self, hs, room_id):
+        self.room_id = room_id
+        self.is_mine_id = hs.is_mine_id
+        self.store = hs.get_datastore()
+
+        self.member_map = {}  # event_id -> (user_id, state)
+        self.rules_by_user = {}  # user_id -> rules
+        self.state_group = object()
+        self.sequence = 0
+
+    @defer.inlineCallbacks
+    def get_rules(self, context):
+        state_group = context.state_group
+        current_state_ids = context.current_state_ids
+
+        if state_group and self.state_group == state_group:
+            defer.returnValue(self.rules_by_user)
+
+        ret_rules_by_user = {}
+        missing_member_event_ids = {}
+        for key, event_id in current_state_ids.iteritems():
+            res = self.member_map.get(event_id, None)
+            if res:
+                user_id, state = res
+                if state == Membership.JOIN:
+                    rules = self.rules_by_user.get(user_id, None)
+                    if rules:
+                        ret_rules_by_user[user_id] = rules
+                continue
+
+            if key[0] != EventTypes.Member:
+                continue
+
+            user_id = key[1]
+            if not self.is_mine_id(user_id):
+                continue
+
+            if self.store.get_if_app_services_interested_in_user(user_id):
+                continue
+
+            missing_member_event_ids[user_id] = event_id
+
+        if missing_member_event_ids:
+            missing_rules = yield self.get_rules_for_member_event_ids(
+                missing_member_event_ids, state_group
+            )
+            ret_rules_by_user.update(missing_rules)
+
+        defer.returnValue(ret_rules_by_user)
+
+    @defer.inlineCallbacks
+    def get_rules_for_member_event_ids(self, member_event_ids, state_group):
+        sequence = self.sequence
+
+        rows = yield self.store._simple_select_many_batch(
+            table="room_memberships",
+            column="event_id",
+            iterable=member_event_ids.values(),
+            retcols=['user_id', 'membership', 'event_id'],
+            keyvalues={},
+            batch_size=500,
+            desc="get_rules_for_member_event_ids",
+        )
+
+        members = {
+            row["event_id"]: (row["user_id"], row["membership"])
+            for row in rows
+        }
+
+        interested_in_user_ids = set(user_id for user_id, _ in members.itervalues())
+
+        if_users_with_pushers = yield self.store.get_if_users_have_pushers(
+            interested_in_user_ids,
+            on_invalidate=self.invalidate_all,
+        )
+
+        user_ids = set(
+            uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
+        )
+
+        users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
+            self.room_id, on_invalidate=self.invalidate_all,
+        )
+
+        # any users with pushers must be ours: they have pushers
+        for uid in users_with_receipts:
+            if uid in interested_in_user_ids:
+                user_ids.add(uid)
+
+        forgotten = yield self.store.who_forgot_in_room(
+            self.room_id, on_invalidate=self.invalidate_all,
+        )
+
+        for row in forgotten:
+            user_id = row["user_id"]
+            event_id = row["event_id"]
+            mem_id = member_event_ids.get(user_id, None)
+            if event_id == mem_id:
+                user_ids.discard(user_id)
+
+        rules_by_user = yield self.store.bulk_get_push_rules(
+            user_ids, on_invalidate=self.invalidate_all,
+        )
+
+        rules_by_user = {k: v for k, v in rules_by_user.iteritems() if v is not None}
+
+        self.update_cache(sequence, members, rules_by_user, state_group)
+
+        defer.returnValue(rules_by_user)
+
+    def invalidate_all(self):
+        self.sequence += 1
+        self.member_map = {}
+        self.rules_by_user = {}
+
+    def update_cache(self, sequence, members, rules_by_user, state_group):
+        if sequence == self.sequence:
+            self.member_map.update(members)
+            self.rules_by_user.update(rules_by_user)
+            self.state_group = state_group

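RulesForRoom caches the push rules for a room's local members keyed by the event context's state group, and protects itself with a sequence counter: invalidate_all() bumps the sequence and clears the maps, while update_cache() only writes results back if no invalidation happened while the rules were being fetched. A minimal sketch of that sequence-counter pattern in isolation (toy class and names, not the synapse API):

# Toy illustration of the sequence-counter invalidation used by RulesForRoom;
# illustrative only.
class SequencedCache(object):
    def __init__(self):
        self.sequence = 0
        self.data = {}

    def invalidate_all(self):
        # Bump the sequence so any fill that started before this point is
        # discarded when it tries to write back.
        self.sequence += 1
        self.data = {}

    def get(self, key, compute):
        if key in self.data:
            return self.data[key]
        sequence = self.sequence  # snapshot before the slow lookup
        value = compute(key)
        if sequence == self.sequence:
            # Nothing was invalidated while computing, so it is safe to cache.
            self.data[key] = value
        return value
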
synapse/state.py

@@ -175,6 +175,17 @@ class StateHandler(object):
         )
         defer.returnValue(joined_users)
 
+    @defer.inlineCallbacks
+    def get_current_hosts_in_room(self, room_id, latest_event_ids=None):
+        if not latest_event_ids:
+            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+        logger.debug("calling resolve_state_groups from get_current_user_in_room")
+        entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+        joined_hosts = yield self.store.get_joined_hosts(
+            room_id, entry.state_id, entry.state
+        )
+        defer.returnValue(joined_hosts)
+
     @defer.inlineCallbacks
     def compute_event_context(self, event, old_state=None):
         """Build an EventContext structure for the event.

synapse/storage/_base.py

@@ -60,12 +60,12 @@ class LoggingTransaction(object):
         object.__setattr__(self, "database_engine", database_engine)
         object.__setattr__(self, "after_callbacks", after_callbacks)
 
-    def call_after(self, callback, *args):
+    def call_after(self, callback, *args, **kwargs):
         """Call the given callback on the main twisted thread after the
         transaction has finished. Used to invalidate the caches on the
        correct thread.
         """
-        self.after_callbacks.append((callback, args))
+        self.after_callbacks.append((callback, args, kwargs))
 
     def __getattr__(self, name):
         return getattr(self.txn, name)
@@ -319,8 +319,8 @@ class SQLBaseStore(object):
                 inner_func, *args, **kwargs
             )
         finally:
-            for after_callback, after_args in after_callbacks:
-                after_callback(*after_args)
+            for after_callback, after_args, after_kwargs in after_callbacks:
+                after_callback(*after_args, **after_kwargs)
         defer.returnValue(result)
 
     @defer.inlineCallbacks

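call_after() queues cache invalidations and updates to run on the main twisted thread once the database transaction has actually finished; this change just carries keyword arguments through to that deferred call, which the state-group cache update further down in this commit needs. A rough sketch of the mechanism outside of synapse (toy class, not LoggingTransaction):

# Toy illustration of the call_after/after_callbacks pattern with **kwargs;
# not the real LoggingTransaction.
class AfterCallbacks(object):
    def __init__(self):
        self._callbacks = []

    def call_after(self, callback, *args, **kwargs):
        # Remember positional and keyword arguments for later replay.
        self._callbacks.append((callback, args, kwargs))

    def run(self):
        # Called once the transaction has committed.
        for callback, args, kwargs in self._callbacks:
            callback(*args, **kwargs)


def update_cache(sequence, key=None, value=None, full=False):
    pass  # stand-in for e.g. a cache.update call


after = AfterCallbacks()
after.call_after(update_cache, 1, key=("!room:example.org",), value={}, full=True)
after.run()
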
synapse/storage/events.py

@@ -370,6 +370,10 @@ class EventsStore(SQLBaseStore):
                     new_forward_extremeties=new_forward_extremeties,
                 )
                 persist_event_counter.inc_by(len(chunk))
+                for room_id, (_, _, new_state) in current_state_for_room.iteritems():
+                    self.get_current_state_ids.prefill(
+                        (room_id, ), new_state
+                    )
 
     @defer.inlineCallbacks
     def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
@@ -529,7 +533,7 @@ class EventsStore(SQLBaseStore):
            if ev_id in events_to_insert
         }
 
-        defer.returnValue((to_delete, to_insert))
+        defer.returnValue((to_delete, to_insert, current_state))
 
     @defer.inlineCallbacks
     def get_event(self, event_id, check_redacted=True,
@@ -682,7 +686,7 @@ class EventsStore(SQLBaseStore):
     def _update_current_state_txn(self, txn, state_delta_by_room):
         for room_id, current_state_tuple in state_delta_by_room.iteritems():
-            to_delete, to_insert = current_state_tuple
+            to_delete, to_insert, _ = current_state_tuple
             txn.executemany(
                 "DELETE FROM current_state_events WHERE event_id = ?",
                 [(ev_id,) for ev_id in to_delete.itervalues()],

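Persisting events now also hands back the room's new current state, which is used to prefill the get_current_state_ids cache so the next reader does not have to recompute it from the database. A minimal sketch of the prefill idea with a toy read-through cache (not synapse's cached descriptor):

# Toy read-through cache showing what "prefill" buys us; illustrative only.
class ReadCache(object):
    def __init__(self):
        self._cache = {}

    def get(self, key, compute):
        if key not in self._cache:
            self._cache[key] = compute(key)
        return self._cache[key]

    def prefill(self, key, value):
        # Seed the cache with a value we already computed as a side effect
        # of a write, so the next read is a pure cache hit.
        self._cache[key] = value


cache = ReadCache()
# After persisting events we already know the new current state for the room:
cache.prefill(("!room:example.org",), {("m.room.member", "@alice:example.org"): "$event_id"})
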
synapse/storage/push_rule.py

@@ -48,7 +48,7 @@ def _load_rules(rawrules, enabled_map):
 
 class PushRuleStore(SQLBaseStore):
-    @cachedInlineCallbacks()
+    @cachedInlineCallbacks(max_entries=50000)
     def get_push_rules_for_user(self, user_id):
         rows = yield self._simple_select_list(
             table="push_rules",

synapse/storage/roommember.py

@@ -417,25 +417,47 @@ class RoomMemberStore(SQLBaseStore):
             if key[0] == EventTypes.Member
         ]
 
-        rows = yield self._simple_select_many_batch(
-            table="room_memberships",
-            column="event_id",
-            iterable=member_event_ids,
-            retcols=['user_id', 'display_name', 'avatar_url'],
-            keyvalues={
-                "membership": Membership.JOIN,
-            },
-            batch_size=500,
-            desc="_get_joined_users_from_context",
+        event_map = self._get_events_from_cache(
+            member_event_ids,
+            allow_rejected=False,
         )
 
-        users_in_room = {
-            to_ascii(row["user_id"]): ProfileInfo(
-                avatar_url=to_ascii(row["avatar_url"]),
-                display_name=to_ascii(row["display_name"]),
-            )
-            for row in rows
-        }
+        missing_member_event_ids = []
+        users_in_room = {}
+        for event_id, ev_entry in event_map.iteritems():
+            if event_id:
+                if ev_entry.event.membership == Membership.JOIN:
+                    users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
+                        display_name=to_ascii(
+                            ev_entry.event.content.get("displayname", None)
+                        ),
+                        avatar_url=to_ascii(
+                            ev_entry.event.content.get("avatar_url", None)
+                        ),
+                    )
+            else:
+                missing_member_event_ids.append(event_id)
+
+        if missing_member_event_ids:
+            rows = yield self._simple_select_many_batch(
+                table="room_memberships",
+                column="event_id",
+                iterable=member_event_ids,
+                retcols=('user_id', 'display_name', 'avatar_url',),
+                keyvalues={
+                    "membership": Membership.JOIN,
+                },
+                batch_size=500,
+                desc="_get_joined_users_from_context",
+            )
+
+            users_in_room.update({
+                to_ascii(row["user_id"]): ProfileInfo(
+                    avatar_url=to_ascii(row["avatar_url"]),
+                    display_name=to_ascii(row["display_name"]),
+                )
+                for row in rows
+            })
 
         if event is not None and event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
@@ -447,6 +469,17 @@ class RoomMemberStore(SQLBaseStore):
 
         defer.returnValue(users_in_room)
 
+    def get_number_of_users_in_rooms(self, room_id):
+        sql = """SELECT coalesce(count(*), 0) FROM current_state_events
+            INNER JOIN room_memberships USING (room_id, event_id)
+            WHERE room_id = ? AND membership = 'join'
+        """
+        return self._execute(
+            "get_number_of_users_in_rooms",
+            lambda txn: txn.fetchone()[0],
+            sql, room_id
+        )
+
     def is_host_joined(self, room_id, host, state_group, state_ids):
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
@@ -482,6 +515,44 @@ class RoomMemberStore(SQLBaseStore):
 
         defer.returnValue(False)
 
+    def get_joined_hosts(self, room_id, state_group, state_ids):
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._get_joined_hosts(
+            room_id, state_group, state_ids
+        )
+
+    @cachedInlineCallbacks(num_args=3)
+    def _get_joined_hosts(self, room_id, state_group, current_state_ids):
+        # We don't use `state_group`, its there so that we can cache based
+        # on it. However, its important that its never None, since two current_state's
+        # with a state_group of None are likely to be different.
+        # See bulk_get_push_rules_for_room for how we work around this.
+        assert state_group is not None
+
+        joined_hosts = set()
+        for (etype, state_key), event_id in current_state_ids.items():
+            if etype == EventTypes.Member:
+                try:
+                    host = get_domain_from_id(state_key)
+                except:
+                    logger.warn("state_key not user_id: %s", state_key)
+                    continue
+
+                if host in joined_hosts:
+                    continue
+
+                event = yield self.get_event(event_id, allow_none=True)
+                if event and event.content["membership"] == Membership.JOIN:
+                    joined_hosts.add(host)
+
+        defer.returnValue(joined_hosts)
+
     @defer.inlineCallbacks
     def _background_add_membership_profile(self, progress, batch_size):
         target_min_stream_id = progress.get(

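get_joined_hosts() (like is_host_joined() above it) caches by state group and uses the trick described in its comments: when the context has no state group yet, a fresh object() is substituted so that two unrelated "no state group" calls can never collide in the cache, because object() != object(). A tiny self-contained illustration of that sentinel, separate from synapse's code:

# Self-contained illustration of the object() sentinel trick; not synapse code.
def cache_key_for(state_group):
    if not state_group:
        # A brand new object is equal only to itself, so calls that lack a
        # state group can never be confused with each other in the cache.
        state_group = object()
    return state_group

assert cache_key_for(None) != cache_key_for(None)
assert cache_key_for(42) == cache_key_for(42)
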
synapse/storage/state.py

@@ -227,6 +227,14 @@ class StateStore(SQLBaseStore):
             ],
         )
 
+        txn.call_after(
+            self._state_group_cache.update,
+            self._state_group_cache.sequence,
+            key=context.state_group,
+            value=context.current_state_ids,
+            full=True,
+        )
+
         self._simple_insert_many_txn(
             txn,
             table="event_to_state_groups",