Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2017-03-15 16:08:46 +00:00
commit b6b1382be1
14 changed files with 152 additions and 59 deletions

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, RoomID from synapse.types import UserID, RoomID
from twisted.internet import defer from twisted.internet import defer
@ -253,19 +254,35 @@ class Filter(object):
Returns: Returns:
bool: True if the event matches bool: True if the event matches
""" """
sender = event.get("sender", None) # We usually get the full "events" as dictionaries coming through,
if not sender: # except for presence which actually gets passed around as its own
# Presence events have their 'sender' in content.user_id # namedtuple type.
content = event.get("content") if isinstance(event, UserPresenceState):
# account_data has been allowed to have non-dict content, so check type first sender = event.user_id
if isinstance(content, dict): room_id = None
sender = content.get("user_id") ev_type = "m.presence"
is_url = False
else:
sender = event.get("sender", None)
if not sender:
# Presence events had their 'sender' in content.user_id, but are
# now handled above. We don't know if anything else uses this
# form. TODO: Check this and probably remove it.
content = event.get("content")
# account_data has been allowed to have non-dict content, so
# check type first
if isinstance(content, dict):
sender = content.get("user_id")
room_id = event.get("room_id", None)
ev_type = event.get("type", None)
is_url = "url" in event.get("content", {})
return self.check_fields( return self.check_fields(
event.get("room_id", None), room_id,
sender, sender,
event.get("type", None), ev_type,
"url" in event.get("content", {}) is_url,
) )
def check_fields(self, room_id, sender, event_type, contains_url): def check_fields(self, room_id, sender, event_type, contains_url):

View file

@ -99,7 +99,12 @@ class TransactionQueue(object):
# destination -> list of tuple(failure, deferred) # destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {} self.pending_failures_by_dest = {}
# destination -> stream_id of last successfully sent to-device message.
# NB: may be a long or an int.
self.last_device_stream_id_by_dest = {} self.last_device_stream_id_by_dest = {}
# destination -> stream_id of last successfully sent device list
# update.
self.last_device_list_stream_id_by_dest = {} self.last_device_list_stream_id_by_dest = {}
# HACK to get unique tx id # HACK to get unique tx id

View file

@ -19,6 +19,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes from synapse.api.errors import AuthError, Codes
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import ( from synapse.types import (
UserID, StreamToken, UserID, StreamToken,
@ -225,9 +226,17 @@ class InitialSyncHandler(BaseHandler):
"content": content, "content": content,
}) })
now = self.clock.time_msec()
ret = { ret = {
"rooms": rooms_ret, "rooms": rooms_ret,
"presence": presence, "presence": [
{
"type": "m.presence",
"content": format_user_presence_state(event, now),
}
for event in presence
],
"account_data": account_data_events, "account_data": account_data_events,
"receipts": receipt, "receipts": receipt,
"end": now_token.to_string(), "end": now_token.to_string(),

View file

@ -29,6 +29,7 @@ from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -722,9 +723,7 @@ class PresenceHandler(object):
for state in updates for state in updates
]) ])
else: else:
defer.returnValue([ defer.returnValue(updates)
format_user_presence_state(state, now) for state in updates
])
@defer.inlineCallbacks @defer.inlineCallbacks
def set_state(self, target_user, state, ignore_status_msg=False): def set_state(self, target_user, state, ignore_status_msg=False):
@ -798,6 +797,9 @@ class PresenceHandler(object):
as_event=False, as_event=False,
) )
now = self.clock.time_msec()
results[:] = [format_user_presence_state(r, now) for r in results]
is_accepted = { is_accepted = {
row["observed_user_id"]: row["accepted"] for row in presence_list row["observed_user_id"]: row["accepted"] for row in presence_list
} }
@ -850,6 +852,7 @@ class PresenceHandler(object):
) )
state_dict = yield self.get_state(observed_user, as_event=False) state_dict = yield self.get_state(observed_user, as_event=False)
state_dict = format_user_presence_state(state_dict, self.clock.time_msec())
self.federation.send_edu( self.federation.send_edu(
destination=observer_user.domain, destination=observer_user.domain,
@ -982,14 +985,18 @@ def should_notify(old_state, new_state):
return False return False
def format_user_presence_state(state, now): def format_user_presence_state(state, now, include_user_id=True):
"""Convert UserPresenceState to a format that can be sent down to clients """Convert UserPresenceState to a format that can be sent down to clients
and to other servers. and to other servers.
The "user_id" is optional so that this function can be used to format presence
updates for client /sync responses and for federation /send requests.
""" """
content = { content = {
"presence": state.state, "presence": state.state,
"user_id": state.user_id,
} }
if include_user_id:
content["user_id"] = state.user_id
if state.last_active_ts: if state.last_active_ts:
content["last_active_ago"] = now - state.last_active_ts content["last_active_ago"] = now - state.last_active_ts
if state.status_msg and state.state != PresenceState.OFFLINE: if state.status_msg and state.state != PresenceState.OFFLINE:
@ -1028,7 +1035,6 @@ class PresenceEventSource(object):
# sending down the rare duplicate is not a concern. # sending down the rare duplicate is not a concern.
with Measure(self.clock, "presence.get_new_events"): with Measure(self.clock, "presence.get_new_events"):
user_id = user.to_string()
if from_key is not None: if from_key is not None:
from_key = int(from_key) from_key = int(from_key)
@ -1037,18 +1043,7 @@ class PresenceEventSource(object):
max_token = self.store.get_current_presence_token() max_token = self.store.get_current_presence_token()
plist = yield self.store.get_presence_list_accepted(user.localpart) users_interested_in = yield self._get_interested_in(user, explicit_room_id)
users_interested_in = set(row["observed_user_id"] for row in plist)
users_interested_in.add(user_id) # So that we receive our own presence
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
users_interested_in.update(users_who_share_room)
if explicit_room_id:
user_ids = yield self.store.get_users_in_room(explicit_room_id)
users_interested_in.update(user_ids)
user_ids_changed = set() user_ids_changed = set()
changed = None changed = None
@ -1076,16 +1071,13 @@ class PresenceEventSource(object):
updates = yield presence.current_state_for_users(user_ids_changed) updates = yield presence.current_state_for_users(user_ids_changed)
now = self.clock.time_msec() if include_offline:
defer.returnValue((updates.values(), max_token))
defer.returnValue(([ else:
{ defer.returnValue(([
"type": "m.presence", s for s in updates.itervalues()
"content": format_user_presence_state(s, now), if s.state != PresenceState.OFFLINE
} ], max_token))
for s in updates.values()
if include_offline or s.state != PresenceState.OFFLINE
], max_token))
def get_current_key(self): def get_current_key(self):
return self.store.get_current_presence_token() return self.store.get_current_presence_token()
@ -1093,6 +1085,31 @@ class PresenceEventSource(object):
def get_pagination_rows(self, user, pagination_config, key): def get_pagination_rows(self, user, pagination_config, key):
return self.get_new_events(user, from_key=None, include_offline=False) return self.get_new_events(user, from_key=None, include_offline=False)
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _get_interested_in(self, user, explicit_room_id, cache_context):
"""Returns the set of users that the given user should see presence
updates for
"""
user_id = user.to_string()
plist = yield self.store.get_presence_list_accepted(
user.localpart, on_invalidate=cache_context.invalidate,
)
users_interested_in = set(row["observed_user_id"] for row in plist)
users_interested_in.add(user_id) # So that we receive our own presence
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id, on_invalidate=cache_context.invalidate,
)
users_interested_in.update(users_who_share_room)
if explicit_room_id:
user_ids = yield self.store.get_users_in_room(
explicit_room_id, on_invalidate=cache_context.invalidate,
)
users_interested_in.update(user_ids)
defer.returnValue(users_interested_in)
def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now): def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
"""Checks the presence of users that have timed out and updates as """Checks the presence of users that have timed out and updates as

View file

@ -721,14 +721,14 @@ class SyncHandler(object):
extra_users_ids.update(users) extra_users_ids.update(users)
extra_users_ids.discard(user.to_string()) extra_users_ids.discard(user.to_string())
states = yield self.presence_handler.get_states( if extra_users_ids:
extra_users_ids, states = yield self.presence_handler.get_states(
as_event=True, extra_users_ids,
) )
presence.extend(states) presence.extend(states)
# Deduplicate the presence entries so that there's at most one per user # Deduplicate the presence entries so that there's at most one per user
presence = {p["content"]["user_id"]: p for p in presence}.values() presence = {p.user_id: p for p in presence}.values()
presence = sync_config.filter_collection.filter_presence( presence = sync_config.filter_collection.filter_presence(
presence presence

View file

@ -16,6 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.handlers.presence import format_user_presence_state
from synapse.util import DeferredTimedOutError from synapse.util import DeferredTimedOutError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -412,6 +413,15 @@ class Notifier(object):
new_events, new_events,
is_peeking=is_peeking, is_peeking=is_peeking,
) )
elif name == "presence":
now = self.clock.time_msec()
new_events[:] = [
{
"type": "m.presence",
"content": format_user_presence_state(event, now),
}
for event in new_events
]
events.extend(new_events) events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key) end_token = end_token.copy_and_replace(keyname, new_key)

View file

@ -27,4 +27,9 @@ class SlavedIdTracker(object):
self._current = (max if self.step > 0 else min)(self._current, new_id) self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self): def get_current_token(self):
"""
Returns:
int
"""
return self._current return self._current

View file

@ -34,6 +34,8 @@ from saml2.client import Saml2Client
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from twisted.web.client import PartialDownloadError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -417,7 +419,12 @@ class CasTicketServlet(ClientV1RestServlet):
"ticket": request.args["ticket"], "ticket": request.args["ticket"],
"service": self.cas_service_url "service": self.cas_service_url
} }
body = yield http_client.get_raw(uri, args) try:
body = yield http_client.get_raw(uri, args)
except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response
result = yield self.handle_cas_response(request, body, client_redirect_url) result = yield self.handle_cas_response(request, body, client_redirect_url)
defer.returnValue(result) defer.returnValue(result)

View file

@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.types import UserID from synapse.types import UserID
from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -33,6 +34,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(PresenceStatusRestServlet, self).__init__(hs) super(PresenceStatusRestServlet, self).__init__(hs)
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
@ -48,6 +50,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not allowed to see their presence.") raise AuthError(403, "You are not allowed to see their presence.")
state = yield self.presence_handler.get_state(target_user=user) state = yield self.presence_handler.get_state(target_user=user)
state = format_user_presence_state(state, self.clock.time_msec())
defer.returnValue((200, state)) defer.returnValue((200, state))

View file

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, parse_string, parse_integer, parse_boolean RestServlet, parse_string, parse_integer, parse_boolean
) )
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import SyncConfig from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.events.utils import ( from synapse.events.utils import (
@ -28,7 +29,6 @@ from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from ._base import client_v2_patterns from ._base import client_v2_patterns
import copy
import itertools import itertools
import logging import logging
@ -194,12 +194,18 @@ class SyncRestServlet(RestServlet):
defer.returnValue((200, response_content)) defer.returnValue((200, response_content))
def encode_presence(self, events, time_now): def encode_presence(self, events, time_now):
formatted = [] return {
for event in events: "events": [
event = copy.deepcopy(event) {
event['sender'] = event['content'].pop('user_id') "type": "m.presence",
formatted.append(event) "sender": event.user_id,
return {"events": formatted} "content": format_user_presence_state(
event, time_now, include_user_id=False
),
}
for event in events
]
}
def encode_joined(self, rooms, time_now, token_id, event_fields): def encode_joined(self, rooms, time_now, token_id, event_fields):
""" """

View file

@ -357,12 +357,12 @@ class DeviceInboxStore(BackgroundUpdateStore):
""" """
Args: Args:
destination(str): The name of the remote server. destination(str): The name of the remote server.
last_stream_id(int): The last position of the device message stream last_stream_id(int|long): The last position of the device message stream
that the server sent up to. that the server sent up to.
current_stream_id(int): The current position of the device current_stream_id(int|long): The current position of the device
message stream. message stream.
Returns: Returns:
Deferred ([dict], int): List of messages for the device and where Deferred ([dict], int|long): List of messages for the device and where
in the stream the messages got to. in the stream the messages got to.
""" """

View file

@ -308,7 +308,7 @@ class DeviceStore(SQLBaseStore):
"""Get stream of updates to send to remote servers """Get stream of updates to send to remote servers
Returns: Returns:
(now_stream_id, [ { updates }, .. ]) (int, list[dict]): current stream id and list of updates
""" """
now_stream_id = self._device_list_id_gen.get_current_token() now_stream_id = self._device_list_id_gen.get_current_token()

View file

@ -30,6 +30,17 @@ class IdGenerator(object):
def _load_current_id(db_conn, table, column, step=1): def _load_current_id(db_conn, table, column, step=1):
"""
Args:
db_conn (object):
table (str):
column (str):
step (int):
Returns:
int
"""
cur = db_conn.cursor() cur = db_conn.cursor()
if step == 1: if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
@ -131,6 +142,9 @@ class StreamIdGenerator(object):
def get_current_token(self): def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or """Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted. equal to it have been successfully persisted.
Returns:
int
""" """
with self._lock: with self._lock:
if self._unfinished_ids: if self._unfinished_ids:

View file

@ -50,7 +50,7 @@ class StreamChangeCache(object):
def has_entity_changed(self, entity, stream_pos): def has_entity_changed(self, entity, stream_pos):
"""Returns True if the entity may have been updated since stream_pos """Returns True if the entity may have been updated since stream_pos
""" """
assert type(stream_pos) is int assert type(stream_pos) is int or type(stream_pos) is long
if stream_pos < self._earliest_known_stream_pos: if stream_pos < self._earliest_known_stream_pos:
self.metrics.inc_misses() self.metrics.inc_misses()