Merge branch 'develop' into server2server_signing

Commit 09d79b0a9b by Mark Haines, 2014-09-22 18:54:00 +01:00
151 changed files with 14760 additions and 4762 deletions

.gitignore (3 lines changed)

@@ -24,4 +24,7 @@ graph/*.svg
 graph/*.png
 graph/*.dot
+webclient/config.js
+webclient/test/environment-protractor.js
 uploads


@ -1,3 +1,119 @@
Changes in synapse 0.3.3 (2014-09-22)
=====================================
Homeserver:
* Fix bug where you continued to get events for rooms you had left.
Webclient:
* Add support for video calls with basic UI.
* Fix bug where one to one chats were named after your display name rather
than the other person's.
* Fix bug which caused lag when typing in the textarea.
* Refuse to run on browsers we know won't work.
* Trigger pagination when joining new rooms.
* Fix bug where we sometimes didn't display invitations in recents.
* Automatically join room when accepting a VoIP call.
* Disable outgoing and reject incoming calls on browsers we don't support
VoIP in.
* Don't display desktop notifications for messages in the room you are
non-idle and speaking in.
Changes in synapse 0.3.2 (2014-09-18)
=====================================
Webclient:
* Fix bug where an empty "bing words" list in old accounts didn't send
notifications when it should have done.
Changes in synapse 0.3.1 (2014-09-18)
=====================================
This is a release to hotfix v0.3.0 to fix two regressions.
Webclient:
* Fix a regression where we sometimes displayed duplicate events.
* Fix a regression where we didn't immediately remove rooms you were
banned in from the recents list.
Changes in synapse 0.3.0 (2014-09-18)
=====================================
See UPGRADE for information about changes to the client server API, including
breaking backwards compatibility with VoIP calls and registration API.
Homeserver:
* When a user changes their displayname or avatar the server will now update
all their join states to reflect this.
* The server now adds "age" key to events to indicate how old they are. This
is clock independent, so at no point does any server or webclient have to
assume their clock is in sync with everyone else.
* Fix bug where we didn't correctly pull in missing PDUs.
* Fix bug where prev_content key wasn't always returned.
* Add support for password resets.
Webclient:
* Improve page content loading.
* Join/parts now trigger desktop notifications.
* Always show room aliases in the UI if one is present.
* No longer show user-count in the recents side panel.
* Add up & down arrow support to the text box for message sending to step
through your sent history.
* Don't display notifications for our own messages.
* Emotes are now formatted correctly in desktop notifications.
* The recents list now differentiates between public & private rooms.
* Fix bug where when switching between rooms the pagination flickered before
the view jumped to the bottom of the screen.
* Add bing word support.
Registration API:
* The registration API has been overhauled to function like the login API. In
practice, this means registration requests must now include the following:
'type':'m.login.password'. See UPGRADE for more information on this.
* The 'user_id' key has been renamed to 'user' to better match the login API.
* There is an additional login type: 'm.login.email.identity'.
* The command client and web client have been updated to reflect these changes.
Changes in synapse 0.2.3 (2014-09-12)
=====================================
Homeserver:
* Fix bug where we stopped sending events to remote home servers if a
user from that home server left, even if there were some still in the
room.
* Fix bugs in the state conflict resolution where it was incorrectly
rejecting events.
Webclient:
* Display room names and topics.
* Allow setting/editing of room names and topics.
* Display information about rooms on the main page.
* Handle ban and kick events in real time.
* VoIP UI and reliability improvements.
* Add glare support for VoIP.
* Improvements to initial startup speed.
* Don't display duplicate join events.
* Local echo of messages.
* Differentiate sending and sent of local echo.
* Various minor bug fixes.
Changes in synapse 0.2.2 (2014-09-06)
=====================================
Homeserver:
* When the server returns state events it now also includes the previous
content.
* Add support for inviting people when creating a new room.
* Make the homeserver inform the room via `m.room.aliases` when a new alias
is added for a room.
* Validate `m.room.power_level` events.
Webclient:
* Add support for captchas on registration.
* Handle `m.room.aliases` events.
* Asynchronously send messages and show a local echo.
* Inform the UI when a message failed to send.
* Only autoscroll on receiving a new message if the user was already at the
bottom of the screen.
* Add support for ban/kick reasons.
Changes in synapse 0.2.1 (2014-09-03)
=====================================


@@ -102,7 +102,7 @@ service provider in Matrix, unlike WhatsApp, Facebook, Hangouts, etc.
 Synapse ships with two basic demo Matrix clients: webclient (a basic group chat
 web client demo implemented in AngularJS) and cmdclient (a basic Python
-commandline utility which lets you easily see what the JSON APIs are up to).
+command line utility which lets you easily see what the JSON APIs are up to).
 We'd like to invite you to take a look at the Matrix spec, try to run a
 homeserver, and join the existing Matrix chatrooms already out there, experiment
@@ -122,7 +122,7 @@ Homeserver Installation
 First, the dependencies need to be installed. Start by installing
 'python2.7-dev' and the various tools of the compiler toolchain.
-Installing prerequisites on ubuntu::
+Installing prerequisites on Ubuntu::
 $ sudo apt-get install build-essential python2.7-dev libffi-dev
@@ -151,8 +151,8 @@ you can check PyNaCl out of git directly (https://github.com/pyca/pynacl) and
 installing it. Installing PyNaCl using pip may also work (remember to remove any
 other versions installed by setuputils in, for example, ~/.local/lib).
-On OSX, if you encounter ``clang: error: unknown argument: '-mno-fused-madd'`` you will
-need to ``export CFLAGS=-Qunused-arguments``.
+On OSX, if you encounter ``clang: error: unknown argument: '-mno-fused-madd'``
+you will need to ``export CFLAGS=-Qunused-arguments``.
 This will run a process of downloading and installing into your
 user's .local/lib directory all of the required dependencies that are
@@ -203,9 +203,10 @@ For the first form, simply pass the required hostname (of the machine) as the
     --generate-config
 $ python synapse/app/homeserver.py --config-path homeserver.config
-Alternatively, you can run synapse via synctl - running ``synctl start`` to generate a
-homeserver.yaml config file, where you can then edit server-name to specify
-machine.my.domain.name, and then set the actual server running again with synctl start.
+Alternatively, you can run synapse via synctl - running ``synctl start`` to
+generate a homeserver.yaml config file, where you can then edit server-name to
+specify machine.my.domain.name, and then set the actual server running again
+with synctl start.
 For the second form, first create your SRV record and publish it in DNS. This
 needs to be named _matrix._tcp.YOURDOMAIN, and point at at least one hostname
@@ -293,7 +294,8 @@ track 3PID logins and publish end-user public keys.
 It's currently early days for identity servers as Matrix is not yet using 3PIDs
 as the primary means of identity and E2E encryption is not complete. As such,
-we are running a single identity server (http://matrix.org:8090) at the current time.
+we are running a single identity server (http://matrix.org:8090) at the current
+time.
 Where's the spec?!


@ -1,3 +1,34 @@
Upgrading to v0.3.0
===================
The registration API now closely matches the login API. This introduces a bit
more back-and-forth between the home server and the client, but improves the
overall flexibility of the API. You can now GET on /register to retrieve a list
of valid registration flows. Upon choosing one, it is submitted in the same way
as login, e.g.::

  {
    "type": "m.login.password",
    "user": "foo",
    "password": "bar"
  }

The default HS supports two flows, with and without Identity Server email
authentication. Enabling captcha on the HS will add an extra step to all
flows: ``m.login.recaptcha``, which must be completed before you can transition
to the next stage. There is a new login type: ``m.login.email.identity``, which
contains the ``threepidCreds`` key that was previously sent in the original
register request. For more information on this, see the specification.
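As a rough illustration only (this snippet is not part of synapse; the base
URL, the client API prefix and the credentials are assumptions), a client might
drive the password flow like this::

    import json
    import urllib2

    # Assumes the client-server API is mounted at this prefix on a local HS.
    base = "http://localhost:8008/_matrix/client/api/v1"

    # 1. Ask the home server which registration flows it supports.
    flows = json.load(urllib2.urlopen(base + "/register"))["flows"]
    assert any(f["type"] == "m.login.password" and "stages" not in f
               for f in flows)

    # 2. Submit the chosen flow, shaped like a login request.
    body = {"type": "m.login.password", "user": "foo", "password": "bar"}
    req = urllib2.Request(base + "/register", json.dumps(body),
                          {"Content-Type": "application/json"})
    print json.load(urllib2.urlopen(req))  # user_id, access_token, ...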
Web Client
----------
The VoIP specification has changed between v0.2.0 and v0.3.0. Users should
refresh any browser tabs to get the latest web client code. Users on
v0.2.0 of the web client will not be able to call those on v0.3.0 and
vice versa.
Upgrading to v0.2.0
===================


@@ -1 +1 @@
-0.2.1
+0.3.3


@@ -5,3 +5,5 @@ Broad-sweeping stuff which would be nice to have
 - homeserver implementation in go
 - homeserver implementation in node.js
 - client SDKs
+- libpurple library
+- irssi plugin?


@ -145,35 +145,50 @@ class SynapseCmd(cmd.Cmd):
<noupdate> : Do not automatically clobber config values. <noupdate> : Do not automatically clobber config values.
""" """
args = self._parse(line, ["userid", "noupdate"]) args = self._parse(line, ["userid", "noupdate"])
path = "/register"
password = None password = None
pwd = None pwd = None
pwd2 = "_" pwd2 = "_"
while pwd != pwd2: while pwd != pwd2:
pwd = getpass.getpass("(Optional) Type a password for this user: ") pwd = getpass.getpass("Type a password for this user: ")
if len(pwd) == 0:
print "Not using a password for this user."
break
pwd2 = getpass.getpass("Retype the password: ") pwd2 = getpass.getpass("Retype the password: ")
if pwd != pwd2: if pwd != pwd2 or len(pwd) == 0:
print "Password mismatch." print "Password mismatch."
pwd = None
else: else:
password = pwd password = pwd
body = {} body = {
"type": "m.login.password"
}
if "userid" in args: if "userid" in args:
body["user_id"] = args["userid"] body["user"] = args["userid"]
if password: if password:
body["password"] = password body["password"] = password
reactor.callFromThread(self._do_register, "POST", path, body, reactor.callFromThread(self._do_register, body,
"noupdate" not in args) "noupdate" not in args)
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_register(self, method, path, data, update_config): def _do_register(self, data, update_config):
url = self._url() + path # check the registration flows
json_res = yield self.http_client.do_request(method, url, data=data) url = self._url() + "/register"
json_res = yield self.http_client.do_request("GET", url)
print json.dumps(json_res, indent=4)
passwordFlow = None
for flow in json_res["flows"]:
if flow["type"] == "m.login.recaptcha" or ("stages" in flow and "m.login.recaptcha" in flow["stages"]):
print "Unable to register: Home server requires captcha."
return
if flow["type"] == "m.login.password" and "stages" not in flow:
passwordFlow = flow
break
if not passwordFlow:
return
json_res = yield self.http_client.do_request("POST", url, data=data)
print json.dumps(json_res, indent=4) print json.dumps(json_res, indent=4)
if update_config and "user_id" in json_res: if update_config and "user_id" in json_res:
self.config["user"] = json_res["user_id"] self.config["user"] = json_res["user_id"]


@@ -24,7 +24,7 @@ If you already have an account, you must **login** into it.
 `Try out the fiddle`__
-.. __: http://jsfiddle.net/4q2jyxng/
+.. __: http://jsfiddle.net/gh/get/jquery/1.8.3/matrix-org/synapse/tree/master/jsfiddles/register_login
 Registration
 ------------
@@ -87,7 +87,7 @@ user and **send a message** to that room.
 `Try out the fiddle`__
-.. __: http://jsfiddle.net/zL3zto9g/
+.. __: http://jsfiddle.net/gh/get/jquery/1.8.3/matrix-org/synapse/tree/master/jsfiddles/create_room_send_msg
 Creating a room
 ---------------
@@ -137,7 +137,7 @@ join a room **via a room alias** if one was set up.
 `Try out the fiddle`__
-.. __: http://jsfiddle.net/7fhotf1b/
+.. __: http://jsfiddle.net/gh/get/jquery/1.8.3/matrix-org/synapse/tree/master/jsfiddles/room_memberships
 Inviting a user to a room
 -------------------------
@@ -183,7 +183,7 @@ of getting events, depending on what the client already knows.
 `Try out the fiddle`__
-.. __: http://jsfiddle.net/vw11mg37/
+.. __: http://jsfiddle.net/gh/get/jquery/1.8.3/matrix-org/synapse/tree/master/jsfiddles/event_stream
 Getting all state
 -----------------
@@ -633,4 +633,4 @@ application.
 `Try out the fiddle`__
-.. __: http://jsfiddle.net/uztL3yme/
+.. __: http://jsfiddle.net/gh/get/jquery/1.8.3/matrix-org/synapse/tree/master/jsfiddles/example_app


@ -3,35 +3,38 @@
"apis": [ "apis": [
{ {
"operations": [ "operations": [
{
"method": "GET",
"nickname": "get_registration_info",
"notes": "All login stages MUST be mentioned if there is >1 login type.",
"summary": "Get the login mechanism to use when registering.",
"type": "RegistrationFlows"
},
{ {
"method": "POST", "method": "POST",
"nickname": "register", "nickname": "submit_registration",
"notes": "Volatile: This API is likely to change.", "notes": "If this is part of a multi-stage registration, there MUST be a 'session' key.",
"parameters": [ "parameters": [
{ {
"description": "A registration request", "description": "A registration submission",
"name": "body", "name": "body",
"paramType": "body", "paramType": "body",
"required": true, "required": true,
"type": "RegistrationRequest" "type": "RegistrationSubmission"
} }
], ],
"responseMessages": [ "responseMessages": [
{ {
"code": 400, "code": 400,
"message": "No JSON object." "message": "Bad login type"
}, },
{ {
"code": 400, "code": 400,
"message": "User ID must only contain characters which do not require url encoding." "message": "Missing JSON keys"
},
{
"code": 400,
"message": "User ID already taken."
} }
], ],
"summary": "Register with the home server.", "summary": "Submit a registration action.",
"type": "RegistrationResponse" "type": "RegistrationResult"
} }
], ],
"path": "/register" "path": "/register"
@ -42,30 +45,68 @@
"application/json" "application/json"
], ],
"models": { "models": {
"RegistrationResponse": { "RegistrationFlows": {
"id": "RegistrationResponse", "id": "RegistrationFlows",
"properties": { "properties": {
"access_token": { "flows": {
"description": "The access token for this user.", "description": "A list of valid registration flows.",
"type": "string" "type": "array",
"items": {
"$ref": "RegistrationInfo"
}
}
}
},
"RegistrationInfo": {
"id": "RegistrationInfo",
"properties": {
"stages": {
"description": "Multi-stage registration only: An array of all the login types required to registration.",
"items": {
"$ref": "string"
},
"type": "array"
}, },
"user_id": { "type": {
"description": "The fully-qualified user ID.", "description": "The first login type that must be used when logging in.",
"type": "string"
},
"home_server": {
"description": "The name of the home server.",
"type": "string" "type": "string"
} }
} }
}, },
"RegistrationRequest": { "RegistrationResult": {
"id": "RegistrationRequest", "id": "RegistrationResult",
"properties": { "properties": {
"access_token": {
"description": "The access token for this user's registration if this is the final stage of the registration process.",
"type": "string"
},
"user_id": { "user_id": {
"description": "The desired user ID. If not specified, a random user ID will be allocated.", "description": "The user's fully-qualified user ID.",
"type": "string", "type": "string"
"required": false },
"next": {
"description": "Multi-stage registration only: The next registration type to submit.",
"type": "string"
},
"session": {
"description": "Multi-stage registration only: The session token to send when submitting the next registration type.",
"type": "string"
}
}
},
"RegistrationSubmission": {
"id": "RegistrationSubmission",
"properties": {
"type": {
"description": "The type of registration being submitted.",
"type": "string"
},
"session": {
"description": "Multi-stage registration only: The session token from an earlier registration stage.",
"type": "string"
},
"_registration_type_defined_keys_": {
"description": "Keys as defined by the specified registration type, e.g. \"user\", \"password\""
} }
} }
} }

docs/freenode.txt (new file, 1 line)

@ -0,0 +1 @@
NCjcRSEG

docs/human-id-rules.rst (new file, 79 lines)

@ -0,0 +1,79 @@
This document outlines the format for human-readable IDs within Matrix.
Overview
--------
UTF-8 is quickly becoming the standard character encoding set on the web. As
such, Matrix requires that all strings MUST be encoded as UTF-8. However,
using Unicode as the character set for human-readable IDs is troublesome. There
are many different characters which appear identical to each other, but would
identify different users. In addition, there are non-printable characters which
cannot be rendered by the end-user. This opens up a security vulnerability with
phishing/spoofing of IDs, commonly known as a homograph attack.
Web browsers encountered this problem when International Domain Names were
introduced. A variety of checks were put in place in order to protect users. If
an address failed the check, the raw punycode would be displayed to disambiguate
the address. Similar checks are performed by home servers in Matrix. However,
Matrix does not use punycode representations, and so does not show raw punycode
on a failed check. Instead, home servers must outright reject these misleading
IDs.
Types of human-readable IDs
---------------------------
There are two main human-readable IDs in question:
- Room aliases
- User IDs
Room aliases look like ``#localpart:domain``. These aliases point to opaque
non human-readable room IDs. These pointers can change, so there is already an
issue present with the same ID pointing to a different destination at a later
date.
User IDs look like ``@localpart:domain``. These represent actual end-users, and
unlike room aliases, there is no layer of indirection. This presents a much
greater concern with homograph attacks.
Checks
------
- Similar to web browsers.
- blacklisted chars (e.g. non-printable characters)
- mix of language sets from 'preferred' language not allowed.
- Language sets from CLDR dataset.
- Treated in segments (localpart, domain)
- Additional restrictions for ease of processing IDs.
- Room alias localparts MUST NOT have ``#`` or ``:``.
- User ID localparts MUST NOT have ``@`` or ``:``.
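The following is a minimal sketch of the per-segment checks above (illustrative
only: the ``localpart_ok`` helper is hypothetical and the CLDR language-set
rules are omitted)::

    # -*- coding: utf-8 -*-
    import unicodedata

    def localpart_ok(localpart, forbidden):
        """Reject forbidden and non-printable characters in a localpart."""
        for c in localpart:
            if c in forbidden:
                return False
            if unicodedata.category(c).startswith("C"):  # control/unassigned
                return False
        return True

    localpart_ok(u"foo", forbidden=u"#:")   # room alias localpart -> True
    localpart_ok(u"@bad", forbidden=u"@:")  # user ID localpart    -> False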
Rejecting
---------
- Home servers MUST reject room aliases which do not pass the check, both on
GETs and PUTs.
- Home servers MUST reject user ID localparts which do not pass the check, both
on creation and on events.
- Any home server whose domain does not pass this check MUST use its punycode
  domain name instead of the IDN, to prevent other home servers from rejecting it.
- Error code is ``M_FAILED_HUMAN_ID_CHECK``. (generic enough for both failing
due to homograph attacks, and failing due to including ``:`` s, etc)
- Error message MAY go into further information about which characters were
rejected and why.
- Error message SHOULD contain a ``failed_keys`` key which contains an array
of strings which represent the keys which failed the check e.g::
failed_keys: [ user_id, room_alias ]
Other considerations
--------------------
- Basic security: Informational key on the event attached by HS to say "unsafe
ID". Problem: clients can just ignore it, and since it will appear only very
rarely, easy to forget when implementing clients.
- Moderate security: Requires client handshake. Forces clients to implement
a check, else they cannot communicate with the misleading ID. However, this is
extra overhead in both client implementations and round-trips.
- High security: Outright rejection of the ID at the point of creation /
receiving event. Point of creation rejection is preferable to avoid the ID
entering the system in the first place. However, malicious HSes can just allow
the ID. Hence, other home servers must reject them if they see them in events.
Client never sees the problem ID, provided the HS is correctly implemented.
- High security decided; client doesn't need to worry about it, no additional
protocol complexity aside from rejection of an event.


@ -418,6 +418,16 @@ which can be set when creating a room:
If this is included, an ``m.room.topic`` event will be sent into the room to indicate the If this is included, an ``m.room.topic`` event will be sent into the room to indicate the
topic for the room. See `Room Events`_ for more information on ``m.room.topic``. topic for the room. See `Room Events`_ for more information on ``m.room.topic``.
``invite``
Type:
List
Optional:
Yes
Value:
A list of user ids to invite.
Description:
This will tell the server to invite everyone in the list to the newly created room.
Example:: Example::
{ {
@ -745,15 +755,17 @@ There are several APIs provided to ``GET`` events for a room:
Description: Description:
Get all ``m.room.member`` state events. Get all ``m.room.member`` state events.
Response format: Response format:
``{ "start": "token", "end": "token", "chunk": [ { m.room.member event }, ... ] }`` ``{ "start": "<token>", "end": "<token>", "chunk": [ { m.room.member event }, ... ] }``
Example: Example:
TODO TODO
|/rooms/<room_id>/messages|_ |/rooms/<room_id>/messages|_
Description: Description:
Get all ``m.room.message`` events. Get all ``m.room.message`` and ``m.room.member`` events. This API supports pagination
using ``from`` and ``to`` query parameters, coupled with the ``start`` and ``end``
tokens from an |initialSync|_ API.
Response format: Response format:
``{ TODO }`` ``{ "start": "<token>", "end": "<token>" }``
Example: Example:
TODO TODO
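As an illustration only (the tokens are placeholders, and the ``access_token``
query parameter is assumed to be supplied as for any other request), a client
paginating from the end of an |initialSync|_ response might issue::

    GET /rooms/<room_id>/messages?from=<end_token>&access_token=<access_token>

and then feed the returned ``end`` token back in as the next ``from`` value.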
@ -909,6 +921,22 @@ prefixed with ``m.``
``ban_level`` will be greater than or equal to ``kick_level`` since ``ban_level`` will be greater than or equal to ``kick_level`` since
banning is more severe than kicking. banning is more severe than kicking.
``m.room.aliases``
Summary:
These state events are used to inform the room about what room aliases it has.
Type:
State event
JSON format:
``{ "aliases": ["string", ...] }``
Example:
``{ "aliases": ["#foo:example.com"] }``
Description:
A server `may` inform the room that it has added or removed an alias for
the room. This is purely for informational purposes and may become stale.
Clients `should` check that the room alias is still valid before using it.
The ``state_key`` of the event is the homeserver which owns the room
alias.
``m.room.message`` ``m.room.message``
Summary: Summary:
A message. A message.
@@ -1141,8 +1169,14 @@ This event is sent by the caller when they wish to establish a call.
 Required keys:
 - ``call_id`` : "string" - A unique identifier for the call
 - ``offer`` : "offer object" - The session description
-- ``version`` : "integer" - The version of the VoIP specification this message
-                            adheres to. This specification is version 0.
+- ``version`` : "integer" - The version of the VoIP specification this
+                            message adheres to. This specification is
+                            version 0.
+- ``lifetime`` : "integer" - The time in milliseconds that the invite is
+                             valid for. Once the invite age exceeds this
+                             value, clients should discard it. They
+                             should also no longer show the call as
+                             awaiting an answer in the UI.
 Optional keys:
 None.
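For illustration, the content of such an invite might look like the following
(all values here are hypothetical)::

    {
      "call_id": "c1234",
      "version": 0,
      "lifetime": 60000,
      "offer": {
        "type": "offer",
        "sdp": "v=0\r\no=- 20518 0 IN IP4 0.0.0.0 ..."
      }
    }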
@@ -1154,16 +1188,16 @@ This event is sent by the caller when they wish to establish a call.
 - ``type`` : "string" - The type of session description, in this case 'offer'
 - ``sdp`` : "string" - The SDP text of the session description
-``m.call.candidate``
+``m.call.candidates``
 This event is sent by callers after sending an invite and by the callee after answering.
-Its purpose is to give the other party an additional ICE candidate to try using to
+Its purpose is to give the other party additional ICE candidates to try using to
 communicate.
 Required keys:
 - ``call_id`` : "string" - The ID of the call this event relates to
 - ``version`` : "integer" - The version of the VoIP specification this message
                             adheres to. This specification is version 0.
-- ``candidate`` : "candidate object" - Object describing the candidate.
+- ``candidates`` : "array of candidate objects" - Array of objects describing the candidates.
``Candidate Object`` ``Candidate Object``
@ -1221,7 +1255,33 @@ Or a rejected call:
<------- m.call.hangup <------- m.call.hangup
Calls are negotiated according to the WebRTC specification. Calls are negotiated according to the WebRTC specification.
Glare
-----
This specification aims to address the problem of two users calling each other
at roughly the same time and their invites crossing on the wire. It is a far
better experience for the users if their calls are connected when it is clear
that they both intend to set up a call with one another.
In Matrix, calls are to rooms rather than users (even if those rooms may only
contain one other user) so we consider calls which are to the same room.
The rules for dealing with such a situation are as follows:
- If an invite to a room is received whilst the client is preparing to send an
invite to the same room, the client should cancel its outgoing call and
instead automatically accept the incoming call on behalf of the user.
- If an invite to a room is received after the client has sent an invite to the
same room and is waiting for a response, the client should perform a
lexicographical comparison of the call IDs of the two calls and use the
lesser of the two calls, aborting the greater. If the incoming call is the
lesser, the client should accept this call on behalf of the user.
The call setup should appear seamless to the user, as if they had simply placed
a call and the other party had accepted. Thus, any media stream that had been
set up for use on a call should be transferred and used for the call that
replaces it.
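A minimal sketch of that decision rule (illustrative only; the helper below is
hypothetical and not part of any client)::

    def resolve_glare(local_call_id, remote_call_id, local_invite_sent):
        """Return the call_id of the call that should survive glare."""
        if not local_invite_sent:
            # Still preparing our own invite: drop it and accept the
            # incoming call on behalf of the user.
            return remote_call_id
        # Both invites crossed on the wire: keep the lexicographically
        # lesser call_id and abort the greater.
        return min(local_call_id, remote_call_id)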
Profiles Profiles
======== ========
@ -1251,12 +1311,6 @@ display name other than it being a valid unicode string.
Registration and login Registration and login
====================== ======================
.. WARNING::
The registration API is likely to change.
.. TODO
- TODO Kegan : Make registration like login (just omit the "user" key on the
initial request?)
Clients must register with a home server in order to use Matrix. After Clients must register with a home server in order to use Matrix. After
registering, the client will be given an access token which must be used in ALL registering, the client will be given an access token which must be used in ALL
@ -1269,9 +1323,11 @@ a token sent to their email address, etc. This specification does not define how
home servers should authorise their users who want to login to their existing home servers should authorise their users who want to login to their existing
accounts, but instead defines the standard interface which implementations accounts, but instead defines the standard interface which implementations
should follow so that ANY client can login to ANY home server. Clients login should follow so that ANY client can login to ANY home server. Clients login
using the |login|_ API. using the |login|_ API. Clients register using the |register|_ API. Registration
follows the same procedure as login, but the path requests are sent to are
different.
The login process breaks down into the following: The registration/login process breaks down into the following:
1. Determine the requirements for logging in. 1. Determine the requirements for logging in.
2. Submit the login stage credentials. 2. Submit the login stage credentials.
3. Get credentials or be told the next stage in the login process and repeat 3. Get credentials or be told the next stage in the login process and repeat
@ -1329,7 +1385,7 @@ This specification defines the following login types:
- ``m.login.oauth2`` - ``m.login.oauth2``
- ``m.login.email.code`` - ``m.login.email.code``
- ``m.login.email.url`` - ``m.login.email.url``
- ``m.login.email.identity``
Password-based Password-based
-------------- --------------
@ -1477,6 +1533,31 @@ If the link has not been visited yet, a standard error response with an errcode
``M_LOGIN_EMAIL_URL_NOT_YET`` should be returned. ``M_LOGIN_EMAIL_URL_NOT_YET`` should be returned.
Email-based (identity server)
-----------------------------
:Type:
``m.login.email.identity``
:Description:
Login is supported by authorising an email address with an identity server.
Prior to submitting this, the client should authenticate with an identity server.
After authenticating, the session information should be submitted to the home server.
To respond to this type, reply with::
{
"type": "m.login.email.identity",
"threepidCreds": [
{
"sid": "<identity server session id>",
"clientSecret": "<identity server client secret>",
"idServer": "<url of identity server authed with, e.g. 'matrix.org:8090'>"
}
]
}
N-Factor Authentication N-Factor Authentication
----------------------- -----------------------
Multiple login stages can be combined to create N-factor authentication during login. Multiple login stages can be combined to create N-factor authentication during login.
@ -1946,6 +2027,10 @@ victim would then include in their view of the chatroom history. Other servers
in the chatroom would reject the invalid messages and potentially reject the in the chatroom would reject the invalid messages and potentially reject the
victims messages as well since they depended on the invalid messages. victims messages as well since they depended on the invalid messages.
.. TODO
Track trustworthiness of HS or users based on if they try to pretend they
haven't seen recent events, and fake a splitbrain... --M
Threat: Block Network Traffic Threat: Block Network Traffic
+++++++++++++++++++++++++++++ +++++++++++++++++++++++++++++
@ -2184,6 +2269,9 @@ Transaction:
.. |login| replace:: ``/login`` .. |login| replace:: ``/login``
.. _login: /docs/api/client-server/#!/-login .. _login: /docs/api/client-server/#!/-login
.. |register| replace:: ``/register``
.. _register: /docs/api/client-server/#!/-registration
.. |/rooms/<room_id>/messages| replace:: ``/rooms/<room_id>/messages`` .. |/rooms/<room_id>/messages| replace:: ``/rooms/<room_id>/messages``
.. _/rooms/<room_id>/messages: /docs/api/client-server/#!/-rooms/get_messages .. _/rooms/<room_id>/messages: /docs/api/client-server/#!/-rooms/get_messages


@ -19,7 +19,12 @@ $('.login').live('click', function() {
showLoggedIn(data); showLoggedIn(data);
}, },
error: function(err) { error: function(err) {
alert(JSON.stringify($.parseJSON(err.responseText))); var errMsg = "To try this, you need a home server running!";
var errJson = $.parseJSON(err.responseText);
if (errJson) {
errMsg = JSON.stringify(errJson);
}
alert(errMsg);
} }
}); });
}); });


@ -58,7 +58,12 @@ $('.login').live('click', function() {
showLoggedIn(data); showLoggedIn(data);
}, },
error: function(err) { error: function(err) {
alert(JSON.stringify($.parseJSON(err.responseText))); var errMsg = "To try this, you need a home server running!";
var errJson = $.parseJSON(err.responseText);
if (errJson) {
errMsg = JSON.stringify(errJson);
}
alert(errMsg);
} }
}); });
}); });


@ -0,0 +1,7 @@
name: Example Matrix Client
description: Includes login, live event streaming, creating rooms, sending messages and viewing member lists.
authors:
- matrix.org
resources:
- http://matrix.org
normalize_css: no


@ -20,7 +20,12 @@ $('.register').live('click', function() {
showLoggedIn(data); showLoggedIn(data);
}, },
error: function(err) { error: function(err) {
alert(JSON.stringify($.parseJSON(err.responseText))); var errMsg = "To try this, you need a home server running!";
var errJson = $.parseJSON(err.responseText);
if (errJson) {
errMsg = JSON.stringify(errJson);
}
alert(errMsg);
} }
}); });
}); });
@ -36,7 +41,12 @@ var login = function(user, password) {
showLoggedIn(data); showLoggedIn(data);
}, },
error: function(err) { error: function(err) {
alert(JSON.stringify($.parseJSON(err.responseText))); var errMsg = "To try this, you need a home server running!";
var errJson = $.parseJSON(err.responseText);
if (errJson) {
errMsg = JSON.stringify(errJson);
}
alert(errMsg);
} }
}); });
}; };


@ -28,7 +28,12 @@ $('.login').live('click', function() {
showLoggedIn(data); showLoggedIn(data);
}, },
error: function(err) { error: function(err) {
alert(JSON.stringify($.parseJSON(err.responseText))); var errMsg = "To try this, you need a home server running!";
var errJson = $.parseJSON(err.responseText);
if (errJson) {
errMsg = JSON.stringify(errJson);
}
alert(errMsg);
} }
}); });
}); });


@ -16,4 +16,4 @@
""" This is a reference implementation of a synapse home server. """ This is a reference implementation of a synapse home server.
""" """
__version__ = "0.2.1" __version__ = "0.3.3"


@ -18,8 +18,8 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership, JoinRules from synapse.api.constants import Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
from synapse.api.events.room import RoomMemberEvent from synapse.api.events.room import RoomMemberEvent, RoomPowerLevelsEvent
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
import logging import logging
@ -67,6 +67,9 @@ class Auth(object):
else: else:
yield self._can_send_event(event) yield self._can_send_event(event)
if event.type == RoomPowerLevelsEvent.TYPE:
yield self._check_power_levels(event)
defer.returnValue(True) defer.returnValue(True)
else: else:
raise AuthError(500, "Unknown event: %s" % event) raise AuthError(500, "Unknown event: %s" % event)
@ -172,7 +175,7 @@ class Auth(object):
if kick_level: if kick_level:
kick_level = int(kick_level) kick_level = int(kick_level)
else: else:
kick_level = 5 kick_level = 50
if user_level < kick_level: if user_level < kick_level:
raise AuthError( raise AuthError(
@ -189,7 +192,7 @@ class Auth(object):
if ban_level: if ban_level:
ban_level = int(ban_level) ban_level = int(ban_level)
else: else:
ban_level = 5 # FIXME (erikj): What should we do here? ban_level = 50 # FIXME (erikj): What should we do here?
if user_level < ban_level: if user_level < ban_level:
raise AuthError(403, "You don't have permission to ban") raise AuthError(403, "You don't have permission to ban")
@ -305,7 +308,9 @@ class Auth(object):
else: else:
user_level = 0 user_level = 0
logger.debug("Checking power level for %s, %s", event.user_id, user_level) logger.debug(
"Checking power level for %s, %s", event.user_id, user_level
)
if current_state and hasattr(current_state, "required_power_level"): if current_state and hasattr(current_state, "required_power_level"):
req = current_state.required_power_level req = current_state.required_power_level
@ -315,3 +320,101 @@ class Auth(object):
403, 403,
"You don't have permission to change that state" "You don't have permission to change that state"
) )
@defer.inlineCallbacks
def _check_power_levels(self, event):
for k, v in event.content.items():
if k == "default":
continue
# FIXME (erikj): We don't want hsob_Ts in content.
if k == "hsob_ts":
continue
try:
self.hs.parse_userid(k)
except:
raise SynapseError(400, "Not a valid user_id: %s" % (k,))
try:
int(v)
except:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
current_state = yield self.store.get_current_state(
event.room_id,
event.type,
event.state_key,
)
if not current_state:
return
else:
current_state = current_state[0]
user_level = yield self.store.get_power_level(
event.room_id,
event.user_id,
)
if user_level:
user_level = int(user_level)
else:
user_level = 0
old_list = current_state.content
# FIXME (erikj)
old_people = {k: v for k, v in old_list.items() if k.startswith("@")}
new_people = {
k: v for k, v in event.content.items()
if k.startswith("@")
}
removed = set(old_people.keys()) - set(new_people.keys())
added = set(new_people.keys()) - set(old_people.keys())
same = set(old_people.keys()) & set(new_people.keys())
for r in removed:
if int(old_list[r]) > user_level:
raise AuthError(
403,
"You don't have permission to remove user: %s" % (r, )
)
for n in added:
if int(event.content[n]) > user_level:
raise AuthError(
403,
"You don't have permission to add ops level greater "
"than your own"
)
for s in same:
if int(event.content[s]) != int(old_list[s]):
if int(event.content[s]) > user_level:
raise AuthError(
403,
"You don't have permission to add ops level greater "
"than your own"
)
if "default" in old_list:
old_default = int(old_list["default"])
if old_default > user_level:
raise AuthError(
403,
"You don't have permission to add ops level greater than "
"your own"
)
if "default" in event.content:
new_default = int(event.content["default"])
if new_default > user_level:
raise AuthError(
403,
"You don't have permission to add ops level greater "
"than your own"
)
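# Illustrative only (not part of this change): the content walked over above is
# a mapping of user IDs to integer power levels plus a "default" entry, e.g.
#     {"@alice:example.com": 100, "@bob:example.com": 50, "default": 0}
# Every key other than "default" (and the legacy "hsob_ts") must parse as a
# user ID, and every value must parse as an integer.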


@ -50,3 +50,12 @@ class JoinRules(object):
KNOCK = u"knock" KNOCK = u"knock"
INVITE = u"invite" INVITE = u"invite"
PRIVATE = u"private" PRIVATE = u"private"
class LoginType(object):
PASSWORD = u"m.login.password"
OAUTH = u"m.login.oauth2"
EMAIL_CODE = u"m.login.email.code"
EMAIL_URL = u"m.login.email.url"
EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha"


@ -29,6 +29,8 @@ class Codes(object):
NOT_FOUND = "M_NOT_FOUND" NOT_FOUND = "M_NOT_FOUND"
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
class CodeMessageException(Exception): class CodeMessageException(Exception):
@ -101,6 +103,19 @@ class StoreError(SynapseError):
pass pass
class InvalidCaptchaError(SynapseError):
def __init__(self, code=400, msg="Invalid captcha.", error_url=None,
errcode=Codes.CAPTCHA_INVALID):
super(InvalidCaptchaError, self).__init__(code, msg, errcode)
self.error_url = error_url
def error_dict(self):
return cs_error(
self.msg,
self.errcode,
error_url=self.error_url,
)
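# Illustrative only: given how cs_error builds error bodies elsewhere in this
# module, a rejected captcha would be serialised roughly as
#     {"errcode": "M_CAPTCHA_INVALID", "error": "Invalid captcha.",
#      "error_url": "<recaptcha challenge URL>"}
# (the error_url value here is a hypothetical placeholder).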
class LimitExceededError(SynapseError): class LimitExceededError(SynapseError):
"""A client has sent too many requests and is being throttled. """A client has sent too many requests and is being throttled.
""" """


@ -17,6 +17,19 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.jsonobject import JsonEncodedObject from synapse.util.jsonobject import JsonEncodedObject
def serialize_event(hs, e):
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, SynapseEvent):
return e
d = e.get_dict()
if "age_ts" in d:
d["age"] = int(hs.get_clock().time_msec()) - d["age_ts"]
del d["age_ts"]
return d
class SynapseEvent(JsonEncodedObject): class SynapseEvent(JsonEncodedObject):
"""Base class for Synapse events. These are JSON objects which must abide """Base class for Synapse events. These are JSON objects which must abide
@ -43,6 +56,8 @@ class SynapseEvent(JsonEncodedObject):
"content", # HTTP body, JSON "content", # HTTP body, JSON
"state_key", "state_key",
"required_power_level", "required_power_level",
"age_ts",
"prev_content",
] ]
internal_keys = [ internal_keys = [
@ -141,7 +156,8 @@ class SynapseEvent(JsonEncodedObject):
return "Missing %s key" % key return "Missing %s key" % key
if type(content[key]) != type(template[key]): if type(content[key]) != type(template[key]):
return "Key %s is of the wrong type." % key return "Key %s is of the wrong type (got %s, want %s)" % (
key, type(content[key]), type(template[key]))
if type(content[key]) == dict: if type(content[key]) == dict:
# we must go deeper # we must go deeper
@ -157,7 +173,8 @@ class SynapseEvent(JsonEncodedObject):
class SynapseStateEvent(SynapseEvent): class SynapseStateEvent(SynapseEvent):
def __init__(self, **kwargs):
def __init__(self, **kwargs):
if "state_key" not in kwargs: if "state_key" not in kwargs:
kwargs["state_key"] = "" kwargs["state_key"] = ""
super(SynapseStateEvent, self).__init__(**kwargs) super(SynapseStateEvent, self).__init__(**kwargs)


@ -47,15 +47,26 @@ class EventFactory(object):
self._event_list[event_class.TYPE] = event_class self._event_list[event_class.TYPE] = event_class
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.hs = hs
def create_event(self, etype=None, **kwargs): def create_event(self, etype=None, **kwargs):
kwargs["type"] = etype kwargs["type"] = etype
if "event_id" not in kwargs: if "event_id" not in kwargs:
kwargs["event_id"] = random_string(10) kwargs["event_id"] = "%s@%s" % (
random_string(10), self.hs.hostname
)
if "ts" not in kwargs: if "ts" not in kwargs:
kwargs["ts"] = int(self.clock.time_msec()) kwargs["ts"] = int(self.clock.time_msec())
# The "age" key is a delta timestamp that should be converted into an
# absolute timestamp the minute we see it.
if "age" in kwargs:
kwargs["age_ts"] = int(self.clock.time_msec()) - int(kwargs["age"])
del kwargs["age"]
elif "age_ts" not in kwargs:
kwargs["age_ts"] = int(self.clock.time_msec())
if etype in self._event_list: if etype in self._event_list:
handler = self._event_list[etype] handler = self._event_list[etype]
else: else:


@ -173,3 +173,10 @@ class RoomOpsPowerLevelsEvent(SynapseStateEvent):
def get_content_template(self): def get_content_template(self):
return {} return {}
class RoomAliasesEvent(SynapseStateEvent):
TYPE = "m.room.aliases"
def get_content_template(self):
return {}


@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage import read_schema from synapse.storage import prepare_database
from synapse.server import HomeServer from synapse.server import HomeServer
@ -36,30 +36,14 @@ from daemonize import Daemonize
import twisted.manhole.telnet import twisted.manhole.telnet
import logging import logging
import sqlite3
import os import os
import re import re
import sys import sys
import sqlite3
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SCHEMAS = [
"transactions",
"pdu",
"users",
"profiles",
"presence",
"im",
"room_aliases",
]
# Remember to update this number every time an incompatible change is made to
# database schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 2
class SynapseHomeServer(HomeServer): class SynapseHomeServer(HomeServer):
def build_http_client(self): def build_http_client(self):
@ -80,52 +64,12 @@ class SynapseHomeServer(HomeServer):
) )
def build_db_pool(self): def build_db_pool(self):
""" Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we return adbapi.ConnectionPool(
don't have to worry about overwriting existing content. "sqlite3", self.get_db_name(),
""" check_same_thread=False,
logging.info("Preparing database: %s...", self.db_name) cp_min=1,
cp_max=1
with sqlite3.connect(self.db_name) as db_conn: )
c = db_conn.cursor()
c.execute("PRAGMA user_version")
row = c.fetchone()
if row and row[0]:
user_version = row[0]
if user_version > SCHEMA_VERSION:
raise ValueError("Cannot use this database as it is too " +
"new for the server to understand"
)
elif user_version < SCHEMA_VERSION:
logging.info("Upgrading database from version %d",
user_version
)
# Run every version since after the current version.
for v in range(user_version + 1, SCHEMA_VERSION + 1):
sql_script = read_schema("delta/v%d" % (v))
c.executescript(sql_script)
db_conn.commit()
else:
for sql_loc in SCHEMAS:
sql_script = read_schema(sql_loc)
c.executescript(sql_script)
db_conn.commit()
c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
c.close()
logging.info("Database prepared in %s.", self.db_name)
pool = adbapi.ConnectionPool(
'sqlite3', self.db_name, check_same_thread=False,
cp_min=1, cp_max=1)
return pool
def create_resource_tree(self, web_client, redirect_root_to_web_client): def create_resource_tree(self, web_client, redirect_root_to_web_client):
"""Create the resource tree for this Home Server. """Create the resource tree for this Home Server.
@ -230,10 +174,6 @@ class SynapseHomeServer(HomeServer):
logger.info("Synapse now listening on port %d", unsecure_port) logger.info("Synapse now listening on port %d", unsecure_port)
def run():
reactor.run()
def setup(): def setup():
config = HomeServerConfig.load_config( config = HomeServerConfig.load_config(
"Synapse Homeserver", "Synapse Homeserver",
@ -268,7 +208,15 @@ def setup():
web_client=config.webclient, web_client=config.webclient,
redirect_root_to_web_client=True, redirect_root_to_web_client=True,
) )
hs.start_listening(config.bind_port, config.unsecure_port)
db_name = hs.get_db_name()
logging.info("Preparing database: %s...", db_name)
with sqlite3.connect(db_name) as db_conn:
prepare_database(db_conn)
logging.info("Database prepared in %s.", db_name)
hs.get_db_pool() hs.get_db_pool()
@ -279,12 +227,14 @@ def setup():
f.namespace['hs'] = hs f.namespace['hs'] = hs
reactor.listenTCP(config.manhole, f, interface='127.0.0.1') reactor.listenTCP(config.manhole, f, interface='127.0.0.1')
hs.start_listening(config.bind_port, config.unsecure_port)
if config.daemonize: if config.daemonize:
print config.pid_file print config.pid_file
daemon = Daemonize( daemon = Daemonize(
app="synapse-homeserver", app="synapse-homeserver",
pid=config.pid_file, pid=config.pid_file,
action=run, action=reactor.run,
auto_close_fds=False, auto_close_fds=False,
verbose=True, verbose=True,
logger=logger, logger=logger,
@ -292,7 +242,7 @@ def setup():
daemon.start() daemon.start()
else: else:
run() reactor.run()
if __name__ == '__main__': if __name__ == '__main__':

synapse/config/captcha.py (new file, 46 lines)

@ -0,0 +1,46 @@
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class CaptchaConfig(Config):
def __init__(self, args):
super(CaptchaConfig, self).__init__(args)
self.recaptcha_private_key = args.recaptcha_private_key
self.enable_registration_captcha = args.enable_registration_captcha
self.captcha_ip_origin_is_x_forwarded = (
args.captcha_ip_origin_is_x_forwarded
)
@classmethod
def add_arguments(cls, parser):
super(CaptchaConfig, cls).add_arguments(parser)
group = parser.add_argument_group("recaptcha")
group.add_argument(
"--recaptcha-private-key", type=str, default="YOUR_PRIVATE_KEY",
help="The matching private key for the web client's public key."
)
group.add_argument(
"--enable-registration-captcha", type=bool, default=False,
help="Enables ReCaptcha checks when registering, preventing signup"
+ " unless a captcha is answered. Requires a valid ReCaptcha "
+ "public/private key."
)
group.add_argument(
"--captcha_ip_origin_is_x_forwarded", type=bool, default=False,
help="When checking captchas, use the X-Forwarded-For (XFF) header"
+ " as the client IP and not the actual client IP."
)
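# For illustration only (a hypothetical invocation; adjust the config path and
# keys for your deployment), these options might be enabled as:
#   python synapse/app/homeserver.py --config-path homeserver.config \
#       --enable-registration-captcha True \
#       --recaptcha-private-key YOUR_PRIVATE_KEY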

synapse/config/email.py (new file, 39 lines)

@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class EmailConfig(Config):
def __init__(self, args):
super(EmailConfig, self).__init__(args)
self.email_from_address = args.email_from_address
self.email_smtp_server = args.email_smtp_server
@classmethod
def add_arguments(cls, parser):
super(EmailConfig, cls).add_arguments(parser)
email_group = parser.add_argument_group("email")
email_group.add_argument(
"--email-from-address",
default="FROM@EXAMPLE.COM",
help="The address to send emails from (e.g. for password resets)."
)
email_group.add_argument(
"--email-smtp-server",
default="",
help="The SMTP server to send emails from (e.g. for password resets)."
)


@ -19,11 +19,16 @@ from .logger import LoggingConfig
from .database import DatabaseConfig from .database import DatabaseConfig
from .ratelimiting import RatelimitConfig from .ratelimiting import RatelimitConfig
from .repository import ContentRepositoryConfig from .repository import ContentRepositoryConfig
from .captcha import CaptchaConfig
from .email import EmailConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig): RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
EmailConfig):
pass pass
if __name__=='__main__':
if __name__ == '__main__':
import sys import sys
HomeServerConfig.load_config("Generate config", sys.argv[1:], "HomeServer") HomeServerConfig.load_config("Generate config", sys.argv[1:], "HomeServer")


@ -291,6 +291,13 @@ class ReplicationLayer(object):
def on_incoming_transaction(self, transaction_data): def on_incoming_transaction(self, transaction_data):
transaction = Transaction(**transaction_data) transaction = Transaction(**transaction_data)
for p in transaction.pdus:
if "age" in p:
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
del p["age"]
pdu_list = [Pdu(**p) for p in transaction.pdus]
logger.debug("[%s] Got transaction", transaction.transaction_id) logger.debug("[%s] Got transaction", transaction.transaction_id)
response = yield self.transaction_actions.have_responded(transaction) response = yield self.transaction_actions.have_responded(transaction)
@ -303,8 +310,6 @@ class ReplicationLayer(object):
logger.debug("[%s] Transacition is new", transaction.transaction_id) logger.debug("[%s] Transacition is new", transaction.transaction_id)
pdu_list = [Pdu(**p) for p in transaction.pdus]
dl = [] dl = []
for pdu in pdu_list: for pdu in pdu_list:
dl.append(self._handle_new_pdu(pdu)) dl.append(self._handle_new_pdu(pdu))
@ -405,9 +410,14 @@ class ReplicationLayer(object):
"""Returns a new Transaction containing the given PDUs suitable for """Returns a new Transaction containing the given PDUs suitable for
transmission. transmission.
""" """
pdus = [p.get_dict() for p in pdu_list]
for p in pdus:
if "age_ts" in pdus:
p["age"] = int(self.clock.time_msec()) - p["age_ts"]
return Transaction( return Transaction(
pdus=[p.get_dict() for p in pdu_list],
origin=self.server_name, origin=self.server_name,
pdus=pdus,
ts=int(self._clock.time_msec()), ts=int(self._clock.time_msec()),
destination=None, destination=None,
) )
@ -593,8 +603,21 @@ class _TransactionQueue(object):
logger.debug("TX [%s] Sending transaction...", destination) logger.debug("TX [%s] Sending transaction...", destination)
# Actually send the transaction # Actually send the transaction
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def cb(transaction):
now = int(self._clock.time_msec())
if "pdus" in transaction:
for p in transaction["pdus"]:
if "age_ts" in p:
p["age"] = now - int(p["age_ts"])
return transaction
code, response = yield self.transport_layer.send_transaction( code, response = yield self.transport_layer.send_transaction(
transaction transaction,
on_send_callback=cb,
) )
logger.debug("TX [%s] Sent transaction", destination) logger.debug("TX [%s] Sent transaction", destination)


@ -144,7 +144,7 @@ class TransportLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def send_transaction(self, transaction): def send_transaction(self, transaction, on_send_callback=None):
""" Sends the given Transaction to it's destination """ Sends the given Transaction to it's destination
Args: Args:
@ -165,10 +165,23 @@ class TransportLayer(object):
data = transaction.get_dict() data = transaction.get_dict()
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def cb(destination, method, path_bytes, producer):
if not on_send_callback:
return
transaction = json.loads(producer.body)
new_transaction = on_send_callback(transaction)
producer.reset(new_transaction)
code, response = yield self.client.put_json( code, response = yield self.client.put_json(
transaction.destination, transaction.destination,
path=PREFIX + "/send/%s/" % transaction.transaction_id, path=PREFIX + "/send/%s/" % transaction.transaction_id,
data=data data=data,
on_send_callback=cb,
) )
logger.debug( logger.debug(


@ -69,6 +69,7 @@ class Pdu(JsonEncodedObject):
"prev_state_id", "prev_state_id",
"prev_state_origin", "prev_state_origin",
"required_power_level", "required_power_level",
"user_id",
] ]
internal_keys = [ internal_keys = [


@ -42,9 +42,6 @@ class BaseHandler(object):
retry_after_ms=int(1000*(time_allowed - time_now)), retry_after_ms=int(1000*(time_allowed - time_now)),
) )
class BaseRoomHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _on_new_room_event(self, event, snapshot, extra_destinations=[], def _on_new_room_event(self, event, snapshot, extra_destinations=[],
extra_users=[]): extra_users=[]):


@@ -19,8 +19,10 @@ from ._base import BaseHandler
from synapse.api.errors import SynapseError
from synapse.http.client import HttpClient
from synapse.api.events.room import RoomAliasesEvent

import logging
import sqlite3

logger = logging.getLogger(__name__)

@@ -37,7 +39,8 @@ class DirectoryHandler(BaseHandler):
    )

@defer.inlineCallbacks
def create_association(self, user_id, room_alias, room_id, servers=None):
    # TODO(erikj): Do auth.

    if not room_alias.is_mine:

@@ -54,12 +57,37 @@ class DirectoryHandler(BaseHandler):
if not servers:
    raise SynapseError(400, "Failed to get server list")

try:
    yield self.store.create_room_alias_association(
        room_alias,
        room_id,
        servers
    )
except sqlite3.IntegrityError:
    defer.returnValue("Already exists")
# TODO: Send the room event.
aliases = yield self.store.get_aliases_for_room(room_id)
event = self.event_factory.create_event(
etype=RoomAliasesEvent.TYPE,
state_key=self.hs.hostname,
room_id=room_id,
user_id=user_id,
content={"aliases": aliases},
)
snapshot = yield self.store.snapshot_room(
room_id=room_id,
user_id=user_id,
)
yield self.state_handler.handle_new_event(event, snapshot)
yield self._on_new_room_event(event, snapshot, extra_users=[user_id])
@defer.inlineCallbacks
def get_association(self, room_alias):
    room_id = None
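For reference, the state event emitted by create_association above ends up looking roughly like this (identifiers invented for the example; the content shape comes straight from the handler):

room_aliases_event = {
    "type": "m.room.aliases",              # RoomAliasesEvent.TYPE
    "room_id": "!abc123:example.org",      # illustrative
    "state_key": "example.org",            # the local server name
    "user_id": "@alice:example.org",       # illustrative
    "content": {"aliases": ["#my-room:example.org"]},
}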

@@ -15,7 +15,6 @@
from twisted.internet import defer

from synapse.util.logutils import log_function
from ._base import BaseHandler

@@ -71,10 +70,7 @@ class EventStreamHandler(BaseHandler):
    auth_user, room_ids, pagin_config, timeout
)

chunks = [self.hs.serialize_event(e) for e in events]

chunk = {
    "chunk": chunks,

@@ -92,7 +88,9 @@ class EventStreamHandler(BaseHandler):
# 10 seconds of grace to allow the client to reconnect again
# before we think they're gone
def _later():
    logger.debug(
        "_later stopped_user_eventstream %s", auth_user
    )
    self.distributor.fire(
        "stopped_user_eventstream", auth_user
    )

@@ -93,22 +93,18 @@ class FederationHandler(BaseHandler):
    """
    event = self.pdu_codec.event_from_pdu(pdu)

    logger.debug("Got event: %s", event.event_id)

    with (yield self.lock_manager.lock(pdu.context)):
        if event.is_state and not backfilled:
            is_new_state = yield self.state_handler.handle_new_state(
                pdu
            )
            if not is_new_state:
                return
        else:
            is_new_state = False
        # TODO: Implement something in federation that allows us to
        # respond to PDU.

    if hasattr(event, "state_key") and not is_new_state:
        logger.debug("Ignoring old state.")
        return

    target_is_mine = False
    if hasattr(event, "target_host"):
        target_is_mine = event.target_host == self.hs.hostname

@@ -139,7 +135,11 @@ class FederationHandler(BaseHandler):
    else:
        with (yield self.room_lock.lock(event.room_id)):
            yield self.store.persist_event(
                event,
                backfilled,
                is_new_state=is_new_state
            )

    room = yield self.store.get_room(event.room_id)

@@ -17,9 +17,13 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.errors import LoginError, Codes
from synapse.http.client import PlainHttpClient
from synapse.util.emailutils import EmailException
import synapse.util.emailutils as emailutils

import bcrypt
import logging
import urllib

logger = logging.getLogger(__name__)

@@ -62,4 +66,41 @@ class LoginHandler(BaseHandler):
    defer.returnValue(token)
else:
    logger.warn("Failed password login for user %s", user)
    raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def reset_password(self, user_id, email):
is_valid = yield self._check_valid_association(user_id, email)
logger.info("reset_password user=%s email=%s valid=%s", user_id, email,
is_valid)
if is_valid:
try:
# send an email out
emailutils.send_email(
smtp_server=self.hs.config.email_smtp_server,
from_addr=self.hs.config.email_from_address,
to_addr=email,
subject="Password Reset",
body="TODO."
)
except EmailException as e:
logger.exception(e)
@defer.inlineCallbacks
def _check_valid_association(self, user_id, email):
identity = yield self._query_email(email)
if identity and "mxid" in identity:
if identity["mxid"] == user_id:
defer.returnValue(True)
return
defer.returnValue(False)
@defer.inlineCallbacks
def _query_email(self, email):
httpCli = PlainHttpClient(self.hs)
data = yield httpCli.get_json(
'matrix.org:8090', # TODO FIXME This should be configurable.
"/_matrix/identity/api/v1/lookup?medium=email&address=" +
"%s" % urllib.quote(email)
)
defer.returnValue(data)
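If the address is known to the identity server, the lookup above is assumed to return something like the dict below; only the "mxid" key is actually inspected by _check_valid_association, the other keys are illustrative:

lookup_response = {
    "medium": "email",                 # assumed
    "address": "alice@example.com",    # assumed
    "mxid": "@alice:example.org",      # the only key the handler checks
}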

@@ -19,7 +19,7 @@ from synapse.api.constants import Membership
from synapse.api.events.room import RoomTopicEvent
from synapse.api.errors import RoomError
from synapse.streams.config import PaginationConfig
from ._base import BaseHandler

import logging

@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)

class MessageHandler(BaseHandler):

    def __init__(self, hs):
        super(MessageHandler, self).__init__(hs)

@@ -124,7 +124,7 @@ class MessageHandler(BaseRoomHandler):
)

chunk = {
    "chunk": [self.hs.serialize_event(e) for e in events],
    "start": pagin_config.from_token.to_string(),
    "end": next_token.to_string(),
}

@@ -268,6 +268,9 @@ class MessageHandler(BaseRoomHandler):
    user, pagination_config, None
)

public_rooms = yield self.store.get_rooms(is_public=True)
public_room_ids = [r["room_id"] for r in public_rooms]

limit = pagin_config.limit
if not limit:
    limit = 10

@@ -276,6 +279,8 @@ class MessageHandler(BaseRoomHandler):
d = {
    "room_id": event.room_id,
    "membership": event.membership,
    "visibility": ("public" if event.room_id in
                   public_room_ids else "private"),
}

if event.membership == Membership.INVITE:

@@ -296,7 +301,7 @@ class MessageHandler(BaseRoomHandler):
end_token = now_token.copy_and_replace("room_key", token[1])

d["messages"] = {
    "chunk": [self.hs.serialize_event(m) for m in messages],
    "start": start_token.to_string(),
    "end": end_token.to_string(),
}

@@ -304,7 +309,7 @@ class MessageHandler(BaseRoomHandler):
current_state = yield self.store.get_current_state(
    event.room_id
)
d["state"] = [self.hs.serialize_event(c) for c in current_state]
except:
    logger.exception("Failed to get snapshot")

@@ -796,11 +796,12 @@ class PresenceEventSource(object):
updates = []
# TODO(paul): use a DeferredList ? How to limit concurrency.
for observed_user in cachemap.keys():
    cached = cachemap[observed_user]

    if not (from_key < cached.serial):
        continue

    if (yield self.is_visible(observer_user, observed_user)):
        updates.append((observed_user, cached))

# TODO(paul): limit

@@ -15,9 +15,9 @@
from twisted.internet import defer

from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import Membership
from synapse.api.events.room import RoomMemberEvent

from ._base import BaseHandler

@@ -97,6 +97,8 @@ class ProfileHandler(BaseHandler):
    }
)

yield self._update_join_states(target_user)

@defer.inlineCallbacks
def get_avatar_url(self, target_user):
    if target_user.is_mine:

@@ -144,6 +146,8 @@ class ProfileHandler(BaseHandler):
    }
)

yield self._update_join_states(target_user)

@defer.inlineCallbacks
def collect_presencelike_data(self, user, state):
    if not user.is_mine:

@@ -180,3 +184,39 @@ class ProfileHandler(BaseHandler):
)

defer.returnValue(response)
@defer.inlineCallbacks
def _update_join_states(self, user):
if not user.is_mine:
return
joins = yield self.store.get_rooms_for_user_where_membership_is(
user.to_string(),
[Membership.JOIN],
)
for j in joins:
snapshot = yield self.store.snapshot_room(
j.room_id, j.state_key, RoomMemberEvent.TYPE,
j.state_key
)
content = {
"membership": j.content["membership"],
"prev": j.content["membership"],
}
yield self.distributor.fire(
"collect_presencelike_data", user, content
)
new_event = self.event_factory.create_event(
etype=j.type,
room_id=j.room_id,
state_key=j.state_key,
content=content,
user_id=j.state_key,
)
yield self.state_handler.handle_new_event(new_event, snapshot)
yield self._on_new_room_event(new_event, snapshot)
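The content of the re-sent membership event therefore ends up looking roughly like the sketch below; the display name and avatar keys are assumed to be filled in by the "collect_presencelike_data" observers fired above, not by this handler itself:

updated_member_content = {
    "membership": "join",
    "prev": "join",
    "displayname": "Alice",                        # assumed, added by observers
    "avatar_url": "http://example.com/alice.png",  # assumed, added by observers
}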

@@ -17,7 +17,9 @@
from twisted.internet import defer

from synapse.types import UserID
from synapse.api.errors import (
    SynapseError, RegistrationError, InvalidCaptchaError
)
from ._base import BaseHandler
import synapse.util.stringutils as stringutils
from synapse.http.client import PlainHttpClient

@@ -38,7 +40,7 @@ class RegistrationHandler(BaseHandler):
    self.distributor.declare("registered_user")

@defer.inlineCallbacks
def register(self, localpart=None, password=None):
    """Registers a new client on the server.

    Args:

@@ -51,20 +53,6 @@ class RegistrationHandler(BaseHandler):
    Raises:
        RegistrationError if there was a problem registering.
    """
if threepidCreds:
for c in threepidCreds:
logger.info("validating theeepidcred sid %s on id server %s", c['sid'], c['idServer'])
try:
threepid = yield self._threepid_from_creds(c)
except:
logger.err()
raise RegistrationError(400, "Couldn't validate 3pid")
if not threepid:
raise RegistrationError(400, "Couldn't validate 3pid")
logger.info("got threepid medium %s address %s", threepid['medium'], threepid['address'])
password_hash = None
if password:
    password_hash = bcrypt.hashpw(password, bcrypt.gensalt())

@@ -106,15 +94,54 @@ class RegistrationHandler(BaseHandler):
raise RegistrationError(
    500, "Cannot generate user ID.")
# Now we have a matrix ID, bind it to the threepids we were given
if threepidCreds:
for c in threepidCreds:
# XXX: This should be a deferred list, shouldn't it?
yield self._bind_threepid(c, user_id)
defer.returnValue((user_id, token))
@defer.inlineCallbacks
def check_recaptcha(self, ip, private_key, challenge, response):
"""Checks a recaptcha is correct."""
captcha_response = yield self._validate_captcha(
ip,
private_key,
challenge,
response
)
if not captcha_response["valid"]:
logger.info("Invalid captcha entered from %s. Error: %s",
ip, captcha_response["error_url"])
raise InvalidCaptchaError(
error_url=captcha_response["error_url"]
)
else:
logger.info("Valid captcha entered from %s", ip)
@defer.inlineCallbacks
def register_email(self, threepidCreds):
"""Registers emails with an identity server."""
for c in threepidCreds:
logger.info("validating theeepidcred sid %s on id server %s",
c['sid'], c['idServer'])
try:
threepid = yield self._threepid_from_creds(c)
except:
logger.err()
raise RegistrationError(400, "Couldn't validate 3pid")
if not threepid:
raise RegistrationError(400, "Couldn't validate 3pid")
logger.info("got threepid medium %s address %s",
threepid['medium'], threepid['address'])
@defer.inlineCallbacks
def bind_emails(self, user_id, threepidCreds):
"""Links emails with a user ID and informs an identity server."""
# Now we have a matrix ID, bind it to the threepids we were given
for c in threepidCreds:
# XXX: This should be a deferred list, shouldn't it?
yield self._bind_threepid(c, user_id)
def _generate_token(self, user_id):
    # urlsafe variant uses _ and - so use . as the separator and replace
    # all =s with .s so http clients don't quote =s when it is used as
@@ -129,16 +156,17 @@ class RegistrationHandler(BaseHandler):
def _threepid_from_creds(self, creds):
    httpCli = PlainHttpClient(self.hs)
    # XXX: make this configurable!
    trustedIdServers = ['matrix.org:8090']
    if not creds['idServer'] in trustedIdServers:
        logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
                    'credentials', creds['idServer'])
        defer.returnValue(None)
    data = yield httpCli.get_json(
        creds['idServer'],
        "/_matrix/identity/api/v1/3pid/getValidated3pid",
        {'sid': creds['sid'], 'clientSecret': creds['clientSecret']}
    )

    if 'medium' in data:
        defer.returnValue(data)
    defer.returnValue(None)
@@ -149,9 +177,45 @@ class RegistrationHandler(BaseHandler):
data = yield httpCli.post_urlencoded_get_json(
    creds['idServer'],
    "/_matrix/identity/api/v1/3pid/bind",
    {'sid': creds['sid'], 'clientSecret': creds['clientSecret'],
     'mxid': mxid}
)
defer.returnValue(data)
@defer.inlineCallbacks
def _validate_captcha(self, ip_addr, private_key, challenge, response):
"""Validates the captcha provided.
Returns:
dict: Containing 'valid'(bool) and 'error_url'(str) if invalid.
"""
response = yield self._submit_captcha(ip_addr, private_key, challenge,
response)
# parse Google's response. Lovely format..
lines = response.split('\n')
json = {
"valid": lines[0] == 'true',
"error_url": "http://www.google.com/recaptcha/api/challenge?" +
"error=%s" % lines[1]
}
defer.returnValue(json)
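For reference, the legacy reCAPTCHA verify endpoint replies with a small plain-text body, roughly "true\n..." on success or "false\n<error-code>" on failure; the parsing above reduces it to a dict, e.g. (raw body invented for the example):

raw = "false\nincorrect-captcha-sol"   # illustrative failure response
lines = raw.split('\n')
parsed = {
    "valid": lines[0] == 'true',
    "error_url": "http://www.google.com/recaptcha/api/challenge?"
                 "error=%s" % lines[1],
}
# parsed["valid"] is False; parsed["error_url"] carries the error code.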
@defer.inlineCallbacks
def _submit_captcha(self, ip_addr, private_key, challenge, response):
client = PlainHttpClient(self.hs)
data = yield client.post_urlencoded_get_raw(
"www.google.com:80",
"/recaptcha/api/verify",
# twisted dislikes google's response, no content length.
accept_partial=True,
args={
'privatekey': private_key,
'remoteip': ip_addr,
'challenge': challenge,
'response': response
}
)
defer.returnValue(data)

@@ -25,14 +25,14 @@ from synapse.api.events.room import (
    RoomSendEventLevelEvent, RoomOpsPowerLevelsEvent, RoomNameEvent,
)
from synapse.util import stringutils
from ._base import BaseHandler

import logging

logger = logging.getLogger(__name__)

class RoomCreationHandler(BaseHandler):

@defer.inlineCallbacks
def create_room(self, user_id, room_id, config):

@@ -65,6 +65,13 @@ class RoomCreationHandler(BaseRoomHandler):
else:
    room_alias = None
invite_list = config.get("invite", [])
for i in invite_list:
try:
self.hs.parse_userid(i)
except:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
is_public = config.get("visibility", None) == "public"

if room_id:

@@ -105,7 +112,9 @@ class RoomCreationHandler(BaseRoomHandler):
)

if room_alias:
    directory_handler = self.hs.get_handlers().directory_handler
    yield directory_handler.create_association(
        user_id=user_id,
        room_id=room_id,
        room_alias=room_alias,
        servers=[self.hs.hostname],
@@ -132,7 +141,7 @@ class RoomCreationHandler(BaseRoomHandler):
etype=RoomNameEvent.TYPE,
room_id=room_id,
user_id=user_id,
required_power_level=50,
content={"name": name},
)

@@ -143,7 +152,7 @@ class RoomCreationHandler(BaseRoomHandler):
etype=RoomNameEvent.TYPE,
room_id=room_id,
user_id=user_id,
required_power_level=50,
content={"name": name},
)

@@ -155,7 +164,7 @@ class RoomCreationHandler(BaseRoomHandler):
etype=RoomTopicEvent.TYPE,
room_id=room_id,
user_id=user_id,
required_power_level=50,
content={"topic": topic},
)

@@ -176,6 +185,25 @@ class RoomCreationHandler(BaseRoomHandler):
do_auth=False
)
content = {"membership": Membership.INVITE}
for invitee in invite_list:
invite_event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
state_key=invitee,
room_id=room_id,
user_id=user_id,
content=content
)
yield self.hs.get_handlers().room_member_handler.change_membership(
invite_event,
do_auth=False
)
yield self.hs.get_handlers().room_member_handler.change_membership(
join_event,
do_auth=False
)
result = {"room_id": room_id} result = {"room_id": room_id}
if room_alias: if room_alias:
result["room_alias"] = room_alias.to_string() result["room_alias"] = room_alias.to_string()
@@ -186,7 +214,7 @@ class RoomCreationHandler(BaseRoomHandler):
event_keys = {
    "room_id": room_id,
    "user_id": creator.to_string(),
    "required_power_level": 100,
}

def create(etype, **content):

@@ -203,7 +231,7 @@ class RoomCreationHandler(BaseRoomHandler):
power_levels_event = self.event_factory.create_event(
    etype=RoomPowerLevelsEvent.TYPE,
    content={creator.to_string(): 100, "default": 0},
    **event_keys
)

@@ -215,7 +243,7 @@ class RoomCreationHandler(BaseRoomHandler):
add_state_event = create(
    etype=RoomAddStateLevelEvent.TYPE,
    level=100,
)

send_event = create(

@@ -225,8 +253,8 @@ class RoomCreationHandler(BaseRoomHandler):
ops = create(
    etype=RoomOpsPowerLevelsEvent.TYPE,
    ban_level=50,
    kick_level=50,
)

return [

@@ -239,7 +267,7 @@ class RoomCreationHandler(BaseRoomHandler):
]
class RoomMemberHandler(BaseHandler):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns

@@ -307,7 +335,7 @@ class RoomMemberHandler(BaseRoomHandler):
member_list = yield self.store.get_room_members(room_id=room_id)
event_list = [
    self.hs.serialize_event(entry)
    for entry in member_list
]
chunk_data = {
@@ -560,11 +588,17 @@ class RoomMemberHandler(BaseRoomHandler):
    extra_users=[target_user]
)

class RoomListHandler(BaseHandler):

@defer.inlineCallbacks
def get_public_room_list(self):
    chunk = yield self.store.get_rooms(is_public=True)
for room in chunk:
joined_members = yield self.store.get_room_members(
room_id=room["room_id"],
membership=Membership.JOIN
)
room["num_joined_members"] = len(joined_members)
# FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": chunk})

@@ -16,7 +16,7 @@
from twisted.internet import defer, reactor
from twisted.internet.error import DNSLookupError
from twisted.web.client import _AgentBase, _URI, readBody, FileBodyProducer, PartialDownloadError
from twisted.web.http_headers import Headers

from synapse.http.endpoint import matrix_endpoint

@@ -122,7 +122,7 @@ class TwistedHttpClient(HttpClient):
self.hs = hs

@defer.inlineCallbacks
def put_json(self, destination, path, data, on_send_callback=None):
    if destination in _destination_mappings:
        destination = _destination_mappings[destination]

@@ -131,7 +131,8 @@ class TwistedHttpClient(HttpClient):
"PUT",
path.encode("ascii"),
producer=_JsonProducer(data),
headers_dict={"Content-Type": ["application/json"]},
on_send_callback=on_send_callback,
)

logger.debug("Getting resp body")

@@ -188,11 +189,37 @@ class TwistedHttpClient(HttpClient):
body = yield readBody(response)
defer.returnValue(json.loads(body))
# XXX FIXME : I'm so sorry.
@defer.inlineCallbacks
def post_urlencoded_get_raw(self, destination, path, accept_partial=False, args={}):
if destination in _destination_mappings:
destination = _destination_mappings[destination]
query_bytes = urllib.urlencode(args, True)
response = yield self._create_request(
destination.encode("ascii"),
"POST",
path.encode("ascii"),
producer=FileBodyProducer(StringIO(urllib.urlencode(args))),
headers_dict={"Content-Type": ["application/x-www-form-urlencoded"]}
)
try:
body = yield readBody(response)
defer.returnValue(body)
except PartialDownloadError as e:
if accept_partial:
defer.returnValue(e.response)
else:
raise e
@defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes, param_bytes=b"",
                    query_bytes=b"", producer=None, headers_dict={},
                    retry_on_dns_fail=True, on_send_callback=None):
    """ Creates and sends a request to the given url
    """
    headers_dict[b"User-Agent"] = [b"Synapse"]

@@ -216,6 +243,9 @@ class TwistedHttpClient(HttpClient):
endpoint = self._getEndpoint(reactor, destination);

while True:
    if on_send_callback:
        on_send_callback(destination, method, path_bytes, producer)

    try:
        response = yield self.agent.request(
            destination,

@@ -284,6 +314,9 @@ class _JsonProducer(object):
""" Used by the twisted http client to create the HTTP body from json
"""
def __init__(self, jsn):
    self.reset(jsn)

def reset(self, jsn):
    self.body = encode_canonical_json(jsn)
    self.length = len(self.body)

@@ -45,6 +45,8 @@ class ClientDirectoryServer(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_alias):
    user = yield self.auth.get_user_by_req(request)

    content = _parse_json(request)
    if not "room_id" in content:
        raise SynapseError(400, "Missing room_id key",

@@ -69,12 +71,13 @@ class ClientDirectoryServer(RestServlet):
try:
    yield dir_handler.create_association(
        user.to_string(), room_alias, room_id, servers
    )
except SynapseError as e:
    raise e
except:
    logger.exception("Failed to create association")
    raise

defer.returnValue((200, {}))

@@ -59,7 +59,7 @@ class EventRestServlet(RestServlet):
event = yield handler.get_event(auth_user, event_id)

if event:
    defer.returnValue((200, self.hs.serialize_event(event)))
else:
    defer.returnValue((404, "Event not found."))

@@ -70,7 +70,28 @@ class LoginFallbackRestServlet(RestServlet):
def on_GET(self, request):
    # TODO(kegan): This should be returning some HTML which is capable of
    # hitting LoginRestServlet
    return (200, {})
class PasswordResetRestServlet(RestServlet):
PATTERN = client_path_pattern("/login/reset")
@defer.inlineCallbacks
def on_POST(self, request):
reset_info = _parse_json(request)
try:
email = reset_info["email"]
user_id = reset_info["user_id"]
handler = self.handlers.login_handler
yield handler.reset_password(user_id, email)
# purposefully give no feedback to avoid people hammering different
# combinations.
defer.returnValue((200, {}))
except KeyError:
raise SynapseError(
400,
"Missing keys. Requires 'email' and 'user_id'."
)
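A request to this (still to-be-registered) endpoint would therefore look roughly as follows, assuming the usual client API prefix:

POST /_matrix/client/api/v1/login/reset
{"email": "alice@example.com", "user_id": "@alice:example.org"}

The servlet always answers 200 {} so the endpoint cannot be used to probe which email/user combinations exist.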
def _parse_json(request):

@@ -85,3 +106,4 @@ def _parse_json(request):
def register_servlets(hs, http_server):
    LoginRestServlet(hs).register(http_server)
    # TODO PasswordResetRestServlet(hs).register(http_server)

@@ -51,7 +51,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
yield self.handlers.profile_handler.set_displayname(
    user, auth_user, new_name)

defer.returnValue((200, {}))

def on_OPTIONS(self, request, user_id):
    return (200, {})

@@ -86,7 +86,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
yield self.handlers.profile_handler.set_avatar_url(
    user, auth_user, new_name)

defer.returnValue((200, {}))

def on_OPTIONS(self, request, user_id):
    return (200, {})

@@ -16,58 +16,219 @@
"""This module contains REST servlets to do with registration: /register"""
from twisted.internet import defer

from synapse.api.errors import SynapseError, Codes
from synapse.api.constants import LoginType
from base import RestServlet, client_path_pattern
import synapse.util.stringutils as stringutils

import json
import logging
import urllib

logger = logging.getLogger(__name__)

class RegisterRestServlet(RestServlet):
"""Handles registration with the home server.
This servlet is in control of the registration flow; the registration
handler doesn't have a concept of multi-stages or sessions.
"""
PATTERN = client_path_pattern("/register$")
def __init__(self, hs):
super(RegisterRestServlet, self).__init__(hs)
# sessions are stored as:
# self.sessions = {
# "session_id" : { __session_dict__ }
# }
# TODO: persistent storage
self.sessions = {}
def on_GET(self, request):
if self.hs.config.enable_registration_captcha:
return (200, {
"flows": [
{
"type": LoginType.RECAPTCHA,
"stages": ([LoginType.RECAPTCHA,
LoginType.EMAIL_IDENTITY,
LoginType.PASSWORD])
},
{
"type": LoginType.RECAPTCHA,
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
}
]
})
else:
return (200, {
"flows": [
{
"type": LoginType.EMAIL_IDENTITY,
"stages": ([LoginType.EMAIL_IDENTITY,
LoginType.PASSWORD])
},
{
"type": LoginType.PASSWORD
}
]
})
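For illustration, the simplest flow (no captcha, no email stage) then looks like this; the user id, password and hostname are invented:

POST /register
{"type": "m.login.password", "user": "alice", "password": "s3cret"}

=> {"user_id": "@alice:example.org",
    "access_token": "<token>",
    "home_server": "example.org"}

When captcha or email stages are enabled, each intermediate response instead carries a "session" id and a "next" hint, and the client repeats the POST with that session until it receives an access_token.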
@defer.inlineCallbacks
def on_POST(self, request):
    register_json = _parse_json(request)

    session = (register_json["session"] if "session" in register_json
               else None)
    login_type = None
    if "type" not in register_json:
        raise SynapseError(400, "Missing 'type' key.")

    try:
        login_type = register_json["type"]
        stages = {
            LoginType.RECAPTCHA: self._do_recaptcha,
            LoginType.PASSWORD: self._do_password,
            LoginType.EMAIL_IDENTITY: self._do_email_identity
        }

        session_info = self._get_session_info(request, session)
        logger.debug("%s : session info %s request info %s",
                     login_type, session_info, register_json)
        response = yield stages[login_type](
            request,
            register_json,
            session_info
        )

        if "access_token" not in response:
            # isn't a final response
            response["session"] = session_info["id"]

        defer.returnValue((200, response))
    except KeyError as e:
        logger.exception(e)
        raise SynapseError(400, "Missing JSON keys for login type %s." % login_type)
def on_OPTIONS(self, request):
return (200, {})
def _get_session_info(self, request, session_id):
if not session_id:
# create a new session
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)
self.sessions[session_id] = {
"id": session_id,
LoginType.EMAIL_IDENTITY: False,
LoginType.RECAPTCHA: False
}
return self.sessions[session_id]
def _save_session(self, session):
# TODO: Persistent storage
logger.debug("Saving session %s", session)
self.sessions[session["id"]] = session
def _remove_session(self, session):
logger.debug("Removing session %s", session)
self.sessions.pop(session["id"])
@defer.inlineCallbacks
def _do_recaptcha(self, request, register_json, session):
if not self.hs.config.enable_registration_captcha:
raise SynapseError(400, "Captcha not required.")
challenge = None
user_response = None
try:
challenge = register_json["challenge"]
user_response = register_json["response"]
except KeyError:
    raise SynapseError(400, "Captcha response is required",
                       errcode=Codes.CAPTCHA_NEEDED)

# May be an X-Forwarding-For header depending on config
ip_addr = request.getClientIP()
if self.hs.config.captcha_ip_origin_is_x_forwarded:
    # use the header
    if request.requestHeaders.hasHeader("X-Forwarded-For"):
        ip_addr = request.requestHeaders.getRawHeaders(
            "X-Forwarded-For")[0]

handler = self.handlers.registration_handler
yield handler.check_recaptcha(
ip_addr,
self.hs.config.recaptcha_private_key,
challenge,
user_response
)
session[LoginType.RECAPTCHA] = True # mark captcha as done
self._save_session(session)
defer.returnValue({
"next": [LoginType.PASSWORD, LoginType.EMAIL_IDENTITY]
})
@defer.inlineCallbacks
def _do_email_identity(self, request, register_json, session):
if (self.hs.config.enable_registration_captcha and
not session[LoginType.RECAPTCHA]):
raise SynapseError(400, "Captcha is required.")
threepidCreds = register_json['threepidCreds']
handler = self.handlers.registration_handler
yield handler.register_email(threepidCreds)
session["threepidCreds"] = threepidCreds # store creds for next stage
session[LoginType.EMAIL_IDENTITY] = True # mark email as done
self._save_session(session)
defer.returnValue({
"next": LoginType.PASSWORD
})
@defer.inlineCallbacks
def _do_password(self, request, register_json, session):
if (self.hs.config.enable_registration_captcha and
not session[LoginType.RECAPTCHA]):
# captcha should've been done by this stage!
raise SynapseError(400, "Captcha is required.")
password = register_json["password"].encode("utf-8")
desired_user_id = (register_json["user"].encode("utf-8") if "user"
in register_json else None)
if desired_user_id and urllib.quote(desired_user_id) != desired_user_id:
raise SynapseError(
400,
"User ID must only contain characters which do not " +
"require URL encoding.")
handler = self.handlers.registration_handler
(user_id, token) = yield handler.register(
    localpart=desired_user_id,
    password=password
)

if session[LoginType.EMAIL_IDENTITY]:
    yield handler.bind_emails(user_id, session["threepidCreds"])

result = {
    "user_id": user_id,
    "access_token": token,
    "home_server": self.hs.hostname,
}
self._remove_session(session)
defer.returnValue(result)

def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except ValueError:
raise SynapseError(400, "Content not JSON.")
def register_servlets(hs, http_server): def register_servlets(hs, http_server):

@@ -154,14 +154,14 @@ class RoomStateEventRestServlet(RestServlet):
    # membership events are special
    handler = self.handlers.room_member_handler
    yield handler.change_membership(event)
    defer.returnValue((200, {}))
else:
    # store random bits of state
    msg_handler = self.handlers.message_handler
    yield msg_handler.store_room_data(
        event=event
    )
    defer.returnValue((200, {}))

# TODO: Needs unit testing for generic events + feedback

@@ -249,7 +249,7 @@ class JoinRoomAliasServlet(RestServlet):
)
handler = self.handlers.room_member_handler
yield handler.change_membership(event)
defer.returnValue((200, {}))

@defer.inlineCallbacks
def on_PUT(self, request, room_identifier, txn_id):

@@ -378,7 +378,7 @@ class RoomTriggerBackfill(RestServlet):
handler = self.handlers.federation_handler
events = yield handler.backfill(remote_server, room_id, limit)

res = [self.hs.serialize_event(event) for event in events]
defer.returnValue((200, res))

@@ -416,7 +416,7 @@ class RoomMembershipRestServlet(RestServlet):
)
handler = self.handlers.room_member_handler
yield handler.change_membership(event)
defer.returnValue((200, {}))

@defer.inlineCallbacks
def on_PUT(self, request, room_id, membership_action, txn_id):

@@ -20,6 +20,7 @@
# Imports required for the default HomeServer() implementation
from synapse.federation import initialize_http_replication
from synapse.api.events import serialize_event
from synapse.api.events.factory import EventFactory
from synapse.notifier import Notifier
from synapse.api.auth import Auth

@@ -57,6 +58,7 @@ class BaseHomeServer(object):
DEPENDENCIES = [
    'clock',
    'http_client',
    'db_name',
    'db_pool',
    'persistence_service',
    'replication_layer',

@@ -138,6 +140,9 @@ class BaseHomeServer(object):
    object."""
    return RoomID.from_string(s, hs=self)

def serialize_event(self, e):
    return serialize_event(self, e)

# Build magic accessors for every dependency
for depname in BaseHomeServer.DEPENDENCIES:
    BaseHomeServer._make_dependency_method(depname)

@@ -16,7 +16,7 @@
from twisted.internet import defer

from synapse.federation.pdu_codec import encode_event_id, decode_event_id
from synapse.util.logutils import log_function

from collections import namedtuple

@@ -87,9 +87,11 @@ class StateHandler(object):
# than the power level of the user
# power_level = self._get_power_level_for_event(event)

pdu_id, origin = decode_event_id(event.event_id, self.server_name)

yield self.store.update_current_state(
    pdu_id=pdu_id,
    origin=origin,
    context=key.context,
    pdu_type=key.type,
    state_key=key.state_key

@@ -113,6 +115,8 @@ class StateHandler(object):
is_new = yield self._handle_new_state(new_pdu)

logger.debug("is_new: %s %s %s", is_new, new_pdu.pdu_id, new_pdu.origin)

if is_new:
    yield self.store.update_current_state(
        pdu_id=new_pdu.pdu_id,

@@ -132,7 +136,9 @@ class StateHandler(object):
@defer.inlineCallbacks
@log_function
def _handle_new_state(self, new_pdu):
    tree, missing_branch = yield self.store.get_unresolved_state_tree(
        new_pdu
    )
    new_branch, current_branch = tree

    logger.debug(

@@ -140,64 +146,17 @@ class StateHandler(object):
        new_branch, current_branch
    )

    if missing_branch is not None:
        # We're missing some PDUs. Fetch them.
        # TODO (erikj): Limit this.
        missing_prev = tree[missing_branch][-1]

        pdu_id = missing_prev.prev_state_id
        origin = missing_prev.prev_state_origin

        is_missing = yield self.store.get_pdu(pdu_id, origin) is None
        if not is_missing:
            raise Exception("Conflict resolution failed")

        yield self._replication.get_pdu(
            destination=missing_prev.origin,

@@ -209,23 +168,93 @@ class StateHandler(object):
        updated_current = yield self._handle_new_state(new_pdu)
        defer.returnValue(updated_current)

    if not current_branch:
        # There is no current state
        defer.returnValue(True)
        return

    n = new_branch[-1]
    c = current_branch[-1]

    common_ancestor = n.pdu_id == c.pdu_id and n.origin == c.origin

    if common_ancestor:
        # We found a common ancestor!
        if len(current_branch) == 1:
            # This is a direct clobber so we can just...
            defer.returnValue(True)
    else:
        # We didn't find a common ancestor. This is probably fine.
        pass

    result = yield self._do_conflict_res(
        new_branch, current_branch, common_ancestor
    )
    defer.returnValue(result)
@defer.inlineCallbacks
def _do_conflict_res(self, new_branch, current_branch, common_ancestor):
conflict_res = [
self._do_power_level_conflict_res,
self._do_chain_length_conflict_res,
self._do_hash_conflict_res,
]
for algo in conflict_res:
new_res, curr_res = yield defer.maybeDeferred(
algo,
new_branch, current_branch, common_ancestor
)
if new_res < curr_res:
defer.returnValue(False)
elif new_res > curr_res:
defer.returnValue(True)
raise Exception("Conflict resolution failed.")
@defer.inlineCallbacks
def _do_power_level_conflict_res(self, new_branch, current_branch,
common_ancestor):
new_powers_deferreds = []
for e in new_branch[:-1] if common_ancestor else new_branch:
if hasattr(e, "user_id"):
new_powers_deferreds.append(
self.store.get_power_level(e.context, e.user_id)
)
current_powers_deferreds = []
for e in current_branch[:-1] if common_ancestor else current_branch:
if hasattr(e, "user_id"):
current_powers_deferreds.append(
self.store.get_power_level(e.context, e.user_id)
)
new_powers = yield defer.gatherResults(
new_powers_deferreds,
consumeErrors=True
)
current_powers = yield defer.gatherResults(
current_powers_deferreds,
consumeErrors=True
)
max_power_new = max(new_powers)
max_power_current = max(current_powers)
defer.returnValue(
(max_power_new, max_power_current)
)
def _do_chain_length_conflict_res(self, new_branch, current_branch,
                                  common_ancestor):
    return (len(new_branch), len(current_branch))

def _do_hash_conflict_res(self, new_branch, current_branch,
                          common_ancestor):
    new_str = "".join([p.pdu_id + p.origin for p in new_branch])
    c_str = "".join([p.pdu_id + p.origin for p in current_branch])

@@ -36,7 +36,7 @@ from .registration import RegistrationStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .stream import StreamStore
from .pdu import StatePduStore, PduStore, PdusTable
from .transactions import TransactionStore
from .keys import KeyStore

@@ -48,6 +48,28 @@ import os

logger = logging.getLogger(__name__)
SCHEMAS = [
"transactions",
"pdu",
"users",
"profiles",
"presence",
"im",
"room_aliases",
]
# Remember to update this number every time an incompatible change is made to
# database schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 3
class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying
something went wrong.
"""
pass
class DataStore(RoomMemberStore, RoomStore,
                RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
                PresenceStore, PduStore, StatePduStore, TransactionStore,

@@ -63,7 +85,8 @@ class DataStore(RoomMemberStore, RoomStore,
@defer.inlineCallbacks
@log_function
def persist_event(self, event=None, backfilled=False, pdu=None,
                  is_new_state=True):
    stream_ordering = None
    if backfilled:
        if not self.min_token_deferred.called:

@@ -71,17 +94,20 @@ class DataStore(RoomMemberStore, RoomStore,
            self.min_token -= 1
            stream_ordering = self.min_token

    try:
        yield self.runInteraction(
            self._persist_pdu_event_txn,
            pdu=pdu,
            event=event,
            backfilled=backfilled,
            stream_ordering=stream_ordering,
            is_new_state=is_new_state,
        )
    except _RollbackButIsFineException as e:
        pass

@defer.inlineCallbacks
def get_event(self, event_id, allow_none=False):
    events_dict = yield self._simple_select_one(
        "events",
        {"event_id": event_id},

@@ -92,18 +118,24 @@ class DataStore(RoomMemberStore, RoomStore,
            "content",
            "unrecognized_keys"
        ],
        allow_none=allow_none,
    )

    if not events_dict:
        defer.returnValue(None)

    event = self._parse_event_from_row(events_dict)
    defer.returnValue(event)

def _persist_pdu_event_txn(self, txn, pdu=None, event=None,
                           backfilled=False, stream_ordering=None,
                           is_new_state=True):
    if pdu is not None:
        self._persist_event_pdu_txn(txn, pdu)
    if event is not None:
        return self._persist_event_txn(
            txn, event, backfilled, stream_ordering,
            is_new_state=is_new_state,
        )

def _persist_event_pdu_txn(self, txn, pdu):

@@ -112,6 +144,12 @@ class DataStore(RoomMemberStore, RoomStore,
    del cols["content"]
    del cols["prev_pdus"]
    cols["content_json"] = json.dumps(pdu.content)

    unrec_keys.update({
        k: v for k, v in cols.items()
        if k not in PdusTable.fields
    })

    cols["unrecognized_keys"] = json.dumps(unrec_keys)

    logger.debug("Persisting: %s", repr(cols))

@@ -124,7 +162,8 @@ class DataStore(RoomMemberStore, RoomStore,
    self._update_min_depth_for_context_txn(txn, pdu.context, pdu.depth)

@log_function
def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None,
                       is_new_state=True):
    if event.type == RoomMemberEvent.TYPE:
        self._store_room_member_txn(txn, event)
    elif event.type == FeedbackEvent.TYPE:

@@ -171,13 +210,14 @@ class DataStore(RoomMemberStore, RoomStore,
    try:
        self._simple_insert_txn(txn, "events", vals)
    except:
        logger.warn(
            "Failed to persist, probably duplicate: %s",
            event.event_id,
            exc_info=True,
        )
        raise _RollbackButIsFineException("_persist_event")

    if is_new_state and hasattr(event, "state_key"):
        vals = {
            "event_id": event.event_id,
            "room_id": event.room_id,

@@ -201,8 +241,6 @@ class DataStore(RoomMemberStore, RoomStore,
        }
    )

    return self._get_room_events_max_id_txn(txn)

@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
    sql = (

@@ -220,7 +258,8 @@ class DataStore(RoomMemberStore, RoomStore,
    results = yield self._execute_and_decode(sql, *args)

    events = yield self._parse_events(results)
    defer.returnValue(events)

@defer.inlineCallbacks
def _get_min_token(self):

@@ -269,7 +308,7 @@ class DataStore(RoomMemberStore, RoomStore,
        prev_state_pdu=prev_state_pdu,
    )

    return self.runInteraction(_snapshot)

class Snapshot(object):

@@ -339,3 +378,42 @@ def read_schema(schema):
    """
    with open(schema_path(schema)) as schema_file:
        return schema_file.read()
def prepare_database(db_conn):
""" Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
don't have to worry about overwriting existing content.
"""
c = db_conn.cursor()
c.execute("PRAGMA user_version")
row = c.fetchone()
if row and row[0]:
user_version = row[0]
if user_version > SCHEMA_VERSION:
raise ValueError("Cannot use this database as it is too " +
"new for the server to understand"
)
elif user_version < SCHEMA_VERSION:
logging.info("Upgrading database from version %d",
user_version
)
# Run every schema delta after the current version.
for v in range(user_version + 1, SCHEMA_VERSION + 1):
sql_script = read_schema("delta/v%d" % (v))
c.executescript(sql_script)
db_conn.commit()
else:
for sql_loc in SCHEMAS:
sql_script = read_schema(sql_loc)
c.executescript(sql_script)
db_conn.commit()
c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
c.close()
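A minimal sketch of how a server process might call this at startup (the sqlite3 wiring here is an assumption for illustration; prepare_database itself is the helper above):

import sqlite3

db_conn = sqlite3.connect("homeserver.db")  # path illustrative
prepare_database(db_conn)   # creates the schemas or applies delta/v* upgrades
db_conn.close()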

@@ -17,6 +17,7 @@ import logging
from twisted.internet import defer

from synapse.api.errors import StoreError
from synapse.util.logutils import log_function

import collections
import copy

@@ -25,6 +26,44 @@ import json

logger = logging.getLogger(__name__)

sql_logger = logging.getLogger("synapse.storage.SQL")
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging to the .execute() method."""
__slots__ = ["txn"]
def __init__(self, txn):
object.__setattr__(self, "txn", txn)
def __getattribute__(self, name):
if name == "execute":
return object.__getattribute__(self, "execute")
return getattr(object.__getattribute__(self, "txn"), name)
def __setattr__(self, name, value):
setattr(object.__getattribute__(self, "txn"), name, value)
def execute(self, sql, *args, **kwargs):
# TODO(paul): Maybe use 'info' and 'debug' for values?
sql_logger.debug("[SQL] %s", sql)
try:
if args and args[0]:
values = args[0]
sql_logger.debug("[SQL values] " +
", ".join(("<%s>",) * len(values)), *values)
except:
# Don't let logging failures stop SQL from working
pass
# TODO(paul): Here would be an excellent place to put some timing
# measurements, and log (warning?) slow queries.
return object.__getattribute__(self, "txn").execute(
sql, *args, **kwargs
)
class SQLBaseStore(object):

@@ -34,6 +73,13 @@ class SQLBaseStore(object):
    self.event_factory = hs.get_event_factory()
    self._clock = hs.get_clock()
def runInteraction(self, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
def inner_func(txn, *args, **kwargs):
return func(LoggingTransaction(txn), *args, **kwargs)
return self._db_pool.runInteraction(inner_func, *args, **kwargs)
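Any store method routed through this wrapper now gets per-statement SQL logging for free. A hypothetical example of a store method using it:

def count_users(self):
    # every txn.execute() inside inner() is logged by LoggingTransaction
    def inner(txn):
        txn.execute("SELECT COUNT(*) FROM users")
        return txn.fetchone()[0]
    return self.runInteraction(inner)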
def cursor_to_dict(self, cursor): def cursor_to_dict(self, cursor):
"""Converts a SQL cursor into an list of dicts. """Converts a SQL cursor into an list of dicts.
@ -59,11 +105,6 @@ class SQLBaseStore(object):
Returns: Returns:
The result of decoder(results) The result of decoder(results)
""" """
logger.debug(
"[SQL] %s Args=%s Func=%s",
query, args, decoder.__name__ if decoder else None
)
def interaction(txn): def interaction(txn):
cursor = txn.execute(query, args) cursor = txn.execute(query, args)
if decoder: if decoder:
@ -71,7 +112,7 @@ class SQLBaseStore(object):
else: else:
return cursor.fetchall() return cursor.fetchall()
return self._db_pool.runInteraction(interaction) return self.runInteraction(interaction)
def _execute_and_decode(self, query, *args): def _execute_and_decode(self, query, *args):
return self._execute(self.cursor_to_dict, query, *args) return self._execute(self.cursor_to_dict, query, *args)
@ -87,10 +128,11 @@ class SQLBaseStore(object):
values : dict of new column names and values for them values : dict of new column names and values for them
or_replace : bool; if True performs an INSERT OR REPLACE or_replace : bool; if True performs an INSERT OR REPLACE
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._simple_insert_txn, table, values, or_replace=or_replace self._simple_insert_txn, table, values, or_replace=or_replace
) )
@log_function
def _simple_insert_txn(self, txn, table, values, or_replace=False): def _simple_insert_txn(self, txn, table, values, or_replace=False):
sql = "%s INTO %s (%s) VALUES(%s)" % ( sql = "%s INTO %s (%s) VALUES(%s)" % (
("INSERT OR REPLACE" if or_replace else "INSERT"), ("INSERT OR REPLACE" if or_replace else "INSERT"),
@ -98,6 +140,12 @@ class SQLBaseStore(object):
", ".join(k for k in values), ", ".join(k for k in values),
", ".join("?" for k in values) ", ".join("?" for k in values)
) )
logger.debug(
"[SQL] %s Args=%s",
sql, values.values(),
)
txn.execute(sql, values.values()) txn.execute(sql, values.values())
return txn.lastrowid return txn.lastrowid
@ -164,7 +212,7 @@ class SQLBaseStore(object):
txn.execute(sql, keyvalues.values()) txn.execute(sql, keyvalues.values())
return txn.fetchall() return txn.fetchall()
res = yield self._db_pool.runInteraction(func) res = yield self.runInteraction(func)
defer.returnValue([r[0] for r in res]) defer.returnValue([r[0] for r in res])
@ -187,7 +235,7 @@ class SQLBaseStore(object):
txn.execute(sql, keyvalues.values()) txn.execute(sql, keyvalues.values())
return self.cursor_to_dict(txn) return self.cursor_to_dict(txn)
return self._db_pool.runInteraction(func) return self.runInteraction(func)
def _simple_update_one(self, table, keyvalues, updatevalues, def _simple_update_one(self, table, keyvalues, updatevalues,
retcols=None): retcols=None):
@ -255,7 +303,7 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched") raise StoreError(500, "More than one row matched")
return ret return ret
return self._db_pool.runInteraction(func) return self.runInteraction(func)
def _simple_delete_one(self, table, keyvalues): def _simple_delete_one(self, table, keyvalues):
"""Executes a DELETE query on the named table, expecting to delete a """Executes a DELETE query on the named table, expecting to delete a
@ -276,7 +324,7 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found") raise StoreError(404, "No row found")
if txn.rowcount > 1: if txn.rowcount > 1:
raise StoreError(500, "more than one row matched") raise StoreError(500, "more than one row matched")
return self._db_pool.runInteraction(func) return self.runInteraction(func)
def _simple_max_id(self, table): def _simple_max_id(self, table):
"""Executes a SELECT query on the named table, expecting to return the """Executes a SELECT query on the named table, expecting to return the
@ -294,7 +342,7 @@ class SQLBaseStore(object):
return 0 return 0
return max_id return max_id
return self._db_pool.runInteraction(func) return self.runInteraction(func)
def _parse_event_from_row(self, row_dict): def _parse_event_from_row(self, row_dict):
d = copy.deepcopy({k: v for k, v in row_dict.items() if v}) d = copy.deepcopy({k: v for k, v in row_dict.items() if v})
@ -307,11 +355,34 @@ class SQLBaseStore(object):
d["content"] = json.loads(d["content"]) d["content"] = json.loads(d["content"])
del d["unrecognized_keys"] del d["unrecognized_keys"]
if "age_ts" not in d:
# For compatibility
d["age_ts"] = d["ts"] if "ts" in d else 0
return self.event_factory.create_event( return self.event_factory.create_event(
etype=d["type"], etype=d["type"],
**d **d
) )
def _parse_events(self, rows):
return self.runInteraction(self._parse_events_txn, rows)
def _parse_events_txn(self, txn, rows):
events = [self._parse_event_from_row(r) for r in rows]
sql = "SELECT * FROM events WHERE event_id = ?"
for ev in events:
if hasattr(ev, "prev_state"):
# Load previous state_content.
# TODO: Should we be pulling this out above?
cursor = txn.execute(sql, (ev.prev_state,))
prevs = self.cursor_to_dict(cursor)
if prevs:
prev = self._parse_event_from_row(prevs[0])
ev.prev_content = prev.content
return events
class Table(object): class Table(object):
""" A base class used to store information about a particular table. """ A base class used to store information about a particular table.

@ -92,3 +92,10 @@ class DirectoryStore(SQLBaseStore):
"server": server, "server": server,
} }
) )
def get_aliases_for_room(self, room_id):
return self._simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
)

@ -17,6 +17,7 @@ from twisted.internet import defer
from ._base import SQLBaseStore, Table, JoinHelper from ._base import SQLBaseStore, Table, JoinHelper
from synapse.federation.units import Pdu
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from collections import namedtuple from collections import namedtuple
@ -42,7 +43,7 @@ class PduStore(SQLBaseStore):
PduTuple: If the pdu does not exist in the database, returns None PduTuple: If the pdu does not exist in the database, returns None
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._get_pdu_tuple, pdu_id, origin self._get_pdu_tuple, pdu_id, origin
) )
@ -94,7 +95,7 @@ class PduStore(SQLBaseStore):
list: A list of PduTuples list: A list of PduTuples
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._get_current_state_for_context, self._get_current_state_for_context,
context context
) )
@ -142,7 +143,7 @@ class PduStore(SQLBaseStore):
pdu_origin (str) pdu_origin (str)
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._mark_as_processed, pdu_id, pdu_origin self._mark_as_processed, pdu_id, pdu_origin
) )
@ -151,7 +152,7 @@ class PduStore(SQLBaseStore):
def get_all_pdus_from_context(self, context): def get_all_pdus_from_context(self, context):
"""Get a list of all PDUs for a given context.""" """Get a list of all PDUs for a given context."""
return self._db_pool.runInteraction( return self.runInteraction(
self._get_all_pdus_from_context, context, self._get_all_pdus_from_context, context,
) )
@ -178,7 +179,7 @@ class PduStore(SQLBaseStore):
Return: Return:
list: A list of PduTuples list: A list of PduTuples
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._get_backfill, context, pdu_list, limit self._get_backfill, context, pdu_list, limit
) )
@ -239,7 +240,7 @@ class PduStore(SQLBaseStore):
txn txn
context (str) context (str)
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._get_min_depth_for_context, context self._get_min_depth_for_context, context
) )
@ -308,8 +309,8 @@ class PduStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_oldest_pdus_in_context(self, context): def get_oldest_pdus_in_context(self, context):
"""Get a list of Pdus that we haven't backfilled beyond yet (and haven't """Get a list of Pdus that we haven't backfilled beyond yet (and havent
seen). This list is used when we want to backfill backwards and is the seen). This list is used when we want to backfill backwards and is the
list we send to the remote server. list we send to the remote server.
Args: Args:
@ -345,7 +346,7 @@ class PduStore(SQLBaseStore):
bool bool
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._is_pdu_new, self._is_pdu_new,
pdu_id=pdu_id, pdu_id=pdu_id,
origin=origin, origin=origin,
@ -498,7 +499,7 @@ class StatePduStore(SQLBaseStore):
) )
def get_unresolved_state_tree(self, new_state_pdu): def get_unresolved_state_tree(self, new_state_pdu):
return self._db_pool.runInteraction( return self.runInteraction(
self._get_unresolved_state_tree, new_state_pdu self._get_unresolved_state_tree, new_state_pdu
) )
@ -516,7 +517,7 @@ class StatePduStore(SQLBaseStore):
if not current: if not current:
logger.debug("get_unresolved_state_tree No current state.") logger.debug("get_unresolved_state_tree No current state.")
return return_value return (return_value, None)
return_value.current_branch.append(current) return_value.current_branch.append(current)
@ -524,17 +525,20 @@ class StatePduStore(SQLBaseStore):
txn, new_pdu, current txn, new_pdu, current
) )
missing_branch = None
for branch, prev_state, state in enum_branches: for branch, prev_state, state in enum_branches:
if state: if state:
return_value[branch].append(state) return_value[branch].append(state)
else: else:
# We don't have prev_state :(
missing_branch = branch
break break
return return_value return (return_value, missing_branch)
def update_current_state(self, pdu_id, origin, context, pdu_type, def update_current_state(self, pdu_id, origin, context, pdu_type,
state_key): state_key):
return self._db_pool.runInteraction( return self.runInteraction(
self._update_current_state, self._update_current_state,
pdu_id, origin, context, pdu_type, state_key pdu_id, origin, context, pdu_type, state_key
) )
@ -573,7 +577,7 @@ class StatePduStore(SQLBaseStore):
PduEntry PduEntry
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._get_current_state_pdu, context, pdu_type, state_key self._get_current_state_pdu, context, pdu_type, state_key
) )
@ -622,53 +626,6 @@ class StatePduStore(SQLBaseStore):
return result return result
def get_next_missing_pdu(self, new_pdu):
"""When we get a new state pdu we need to check whether we need to do
any conflict resolution, if we do then we need to check if we need
to go back and request some more state pdus that we haven't seen yet.
Args:
txn
new_pdu
Returns:
PduIdTuple: A pdu that we are missing, or None if we have all the
pdus required to do the conflict resolution.
"""
return self._db_pool.runInteraction(
self._get_next_missing_pdu, new_pdu
)
def _get_next_missing_pdu(self, txn, new_pdu):
logger.debug(
"get_next_missing_pdu %s %s",
new_pdu.pdu_id, new_pdu.origin
)
current = self._get_current_interaction(
txn,
new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
)
if (not current or not current.prev_state_id
or not current.prev_state_origin):
return None
# Oh look, it's a straight clobber, so wooooo almost no-op.
if (new_pdu.prev_state_id == current.pdu_id
and new_pdu.prev_state_origin == current.origin):
return None
enum_branches = self._enumerate_state_branches(txn, new_pdu, current)
for branch, prev_state, state in enum_branches:
if not state:
return PduIdTuple(
prev_state.prev_state_id,
prev_state.prev_state_origin
)
return None
def handle_new_state(self, new_pdu): def handle_new_state(self, new_pdu):
"""Actually perform conflict resolution on the new_pdu on the """Actually perform conflict resolution on the new_pdu on the
assumption we have all the pdus required to perform it. assumption we have all the pdus required to perform it.
@ -679,7 +636,7 @@ class StatePduStore(SQLBaseStore):
Returns: Returns:
bool: True if the new_pdu clobbered the current state, False if not bool: True if the new_pdu clobbered the current state, False if not
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._handle_new_state, new_pdu self._handle_new_state, new_pdu
) )
@ -752,24 +709,11 @@ class StatePduStore(SQLBaseStore):
return is_current return is_current
@classmethod
@log_function @log_function
def _enumerate_state_branches(cls, txn, pdu_a, pdu_b): def _enumerate_state_branches(self, txn, pdu_a, pdu_b):
branch_a = pdu_a branch_a = pdu_a
branch_b = pdu_b branch_b = pdu_b
get_query = (
"SELECT %(fields)s FROM %(pdus)s as p "
"LEFT JOIN %(state)s as s "
"ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
"WHERE p.pdu_id = ? AND p.origin = ? "
) % {
"fields": _pdu_state_joiner.get_fields(
PdusTable="p", StatePdusTable="s"),
"pdus": PdusTable.table_name,
"state": StatePdusTable.table_name,
}
while True: while True:
if (branch_a.pdu_id == branch_b.pdu_id if (branch_a.pdu_id == branch_b.pdu_id
and branch_a.origin == branch_b.origin): and branch_a.origin == branch_b.origin):
@ -801,13 +745,12 @@ class StatePduStore(SQLBaseStore):
branch_a.prev_state_origin branch_a.prev_state_origin
) )
logger.debug("getting branch_a prev %s", pdu_tuple)
txn.execute(get_query, pdu_tuple)
prev_branch = branch_a prev_branch = branch_a
res = txn.fetchone() logger.debug("getting branch_a prev %s", pdu_tuple)
branch_a = PduEntry(*res) if res else None branch_a = self._get_pdu_tuple(txn, *pdu_tuple)
if branch_a:
branch_a = Pdu.from_pdu_tuple(branch_a)
logger.debug("branch_a=%s", branch_a) logger.debug("branch_a=%s", branch_a)
@ -820,14 +763,13 @@ class StatePduStore(SQLBaseStore):
branch_b.prev_state_id, branch_b.prev_state_id,
branch_b.prev_state_origin branch_b.prev_state_origin
) )
txn.execute(get_query, pdu_tuple)
logger.debug("getting branch_b prev %s", pdu_tuple)
prev_branch = branch_b prev_branch = branch_b
res = txn.fetchone() logger.debug("getting branch_b prev %s", pdu_tuple)
branch_b = PduEntry(*res) if res else None branch_b = self._get_pdu_tuple(txn, *pdu_tuple)
if branch_b:
branch_b = Pdu.from_pdu_tuple(branch_b)
logger.debug("branch_b=%s", branch_b) logger.debug("branch_b=%s", branch_b)

@ -62,7 +62,7 @@ class RegistrationStore(SQLBaseStore):
Raises: Raises:
StoreError if the user_id could not be registered. StoreError if the user_id could not be registered.
""" """
yield self._db_pool.runInteraction(self._register, user_id, token, yield self.runInteraction(self._register, user_id, token,
password_hash) password_hash)
def _register(self, txn, user_id, token, password_hash): def _register(self, txn, user_id, token, password_hash):
@ -99,7 +99,7 @@ class RegistrationStore(SQLBaseStore):
Raises: Raises:
StoreError if no user was found. StoreError if no user was found.
""" """
user_id = yield self._db_pool.runInteraction(self._query_for_auth, user_id = yield self.runInteraction(self._query_for_auth,
token) token)
defer.returnValue(user_id) defer.returnValue(user_id)

@ -149,7 +149,7 @@ class RoomStore(SQLBaseStore):
defer.returnValue(None) defer.returnValue(None)
def get_power_level(self, room_id, user_id): def get_power_level(self, room_id, user_id):
return self._db_pool.runInteraction( return self.runInteraction(
self._get_power_level, self._get_power_level,
room_id, user_id, room_id, user_id,
) )
@ -182,7 +182,7 @@ class RoomStore(SQLBaseStore):
return None return None
def get_ops_levels(self, room_id): def get_ops_levels(self, room_id):
return self._db_pool.runInteraction( return self.runInteraction(
self._get_ops_levels, self._get_ops_levels,
room_id, room_id,
) )

@ -18,6 +18,7 @@ from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.util.logutils import log_function
import logging import logging
@ -29,8 +30,18 @@ class RoomMemberStore(SQLBaseStore):
def _store_room_member_txn(self, txn, event): def _store_room_member_txn(self, txn, event):
"""Store a room member in the database. """Store a room member in the database.
""" """
try:
    target_user_id = event.state_key
    domain = self.hs.parse_userid(target_user_id).domain
except:
    logger.exception("Failed to parse target_user_id=%s", target_user_id)
    raise
logger.debug(
"_store_room_member_txn: target_user_id=%s, membership=%s",
target_user_id,
event.membership,
)
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
@ -51,12 +62,30 @@ class RoomMemberStore(SQLBaseStore):
"VALUES (?, ?)" "VALUES (?, ?)"
) )
txn.execute(sql, (event.room_id, domain)) txn.execute(sql, (event.room_id, domain))
else: elif event.membership != Membership.INVITE:
sql = ( # Check if this was the last person to have left.
"DELETE FROM room_hosts WHERE room_id = ? AND host = ?" member_events = self._get_members_query_txn(
txn,
where_clause="c.room_id = ? AND m.membership = ? AND m.user_id != ?",
where_values=(event.room_id, Membership.JOIN, target_user_id,)
) )
txn.execute(sql, (event.room_id, domain)) joined_domains = set()
for e in member_events:
try:
joined_domains.add(
self.hs.parse_userid(e.state_key).domain
)
except:
# FIXME: How do we deal with invalid user ids in the db?
logger.exception("Invalid user_id: %s", event.state_key)
if domain not in joined_domains:
sql = (
"DELETE FROM room_hosts WHERE room_id = ? AND host = ?"
)
txn.execute(sql, (event.room_id, domain))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_member(self, user_id, room_id): def get_room_member(self, user_id, room_id):
@ -88,7 +117,7 @@ class RoomMemberStore(SQLBaseStore):
txn.execute(sql, (user_id, room_id)) txn.execute(sql, (user_id, room_id))
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
if rows: if rows:
return self._parse_event_from_row(rows[0]) return self._parse_events_txn(txn, rows)[0]
else: else:
return None return None
@ -120,7 +149,7 @@ class RoomMemberStore(SQLBaseStore):
membership_list (list): A list of synapse.api.constants.Membership membership_list (list): A list of synapse.api.constants.Membership
values which the user must be in. values which the user must be in.
Returns: Returns:
A list of dicts with "room_id" and "membership" keys. A list of RoomMemberEvent objects
""" """
if not membership_list: if not membership_list:
return defer.succeed(None) return defer.succeed(None)
@ -146,8 +175,13 @@ class RoomMemberStore(SQLBaseStore):
vals = where_dict.values() vals = where_dict.values()
return self._get_members_query(clause, vals) return self._get_members_query(clause, vals)
@defer.inlineCallbacks
def _get_members_query(self, where_clause, where_values): def _get_members_query(self, where_clause, where_values):
return self._db_pool.runInteraction(
self._get_members_query_txn,
where_clause, where_values
)
def _get_members_query_txn(self, txn, where_clause, where_values):
sql = ( sql = (
"SELECT e.* FROM events as e " "SELECT e.* FROM events as e "
"INNER JOIN room_memberships as m " "INNER JOIN room_memberships as m "
@ -157,18 +191,18 @@ class RoomMemberStore(SQLBaseStore):
"WHERE %s " "WHERE %s "
) % (where_clause,) ) % (where_clause,)
rows = yield self._execute_and_decode(sql, *where_values) txn.execute(sql, where_values)
rows = self.cursor_to_dict(txn)
# logger.debug("_get_members_query Got rows %s", rows) results = self._parse_events_txn(txn, rows)
return results
results = [self._parse_event_from_row(r) for r in rows]
defer.returnValue(results)
@defer.inlineCallbacks @defer.inlineCallbacks
def user_rooms_intersect(self, user_list): def user_rooms_intersect(self, user_id_list):
""" Checks whether a list of users share a room. """ Checks whether all the users whose IDs are given in a list share a
room.
""" """
user_list_clause = " OR ".join(["m.user_id = ?"] * len(user_list)) user_list_clause = " OR ".join(["m.user_id = ?"] * len(user_id_list))
sql = ( sql = (
"SELECT m.room_id FROM room_memberships as m " "SELECT m.room_id FROM room_memberships as m "
"INNER JOIN current_state_events as c " "INNER JOIN current_state_events as c "
@ -178,8 +212,8 @@ class RoomMemberStore(SQLBaseStore):
"GROUP BY m.room_id HAVING COUNT(m.room_id) = ?" "GROUP BY m.room_id HAVING COUNT(m.room_id) = ?"
) % {"clause": user_list_clause} ) % {"clause": user_list_clause}
args = user_list args = list(user_id_list)
args.append(len(user_list)) args.append(len(user_id_list))
rows = yield self._execute(None, sql, *args) rows = yield self._execute(None, sql, *args)
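The switch from args = user_list to args = list(user_id_list) is more than a rename: it takes a copy, so the caller's sequence is never mutated and callers may pass a tuple as well as a list. Roughly:

    user_id_list = ("@apple:test", "@banana:test")  # callers may pass a tuple
    args = list(user_id_list)                       # copy into a fresh list
    args.append(len(user_id_list))                  # tuple.append() would have raised here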

@ -0,0 +1,27 @@
/* Copyright 2014 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX IF NOT EXISTS room_aliases_alias ON room_aliases(room_alias);
CREATE INDEX IF NOT EXISTS room_aliases_id ON room_aliases(room_id);
CREATE INDEX IF NOT EXISTS room_alias_servers_alias ON room_alias_servers(room_alias);
DELETE FROM room_aliases WHERE rowid NOT IN (SELECT max(rowid) FROM room_aliases GROUP BY room_alias, room_id);
CREATE UNIQUE INDEX IF NOT EXISTS room_aliases_uniq ON room_aliases(room_alias, room_id);
PRAGMA user_version = 3;
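The ordering in this delta matters: duplicate rows have to be deleted before the unique index can be created. A quick self-contained check of that part of the script, with a simplified table and made-up values:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE room_aliases(room_alias TEXT, room_id TEXT);
        INSERT INTO room_aliases VALUES ('#a:test', '!1:test');
        INSERT INTO room_aliases VALUES ('#a:test', '!1:test');
    """)
    conn.executescript("""
        DELETE FROM room_aliases WHERE rowid NOT IN
            (SELECT max(rowid) FROM room_aliases GROUP BY room_alias, room_id);
        CREATE UNIQUE INDEX IF NOT EXISTS room_aliases_uniq
            ON room_aliases(room_alias, room_id);
    """)
    print(conn.execute("SELECT count(*) FROM room_aliases").fetchall())  # [(1,)]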

@ -146,7 +146,7 @@ class StreamStore(SQLBaseStore):
current_room_membership_sql = ( current_room_membership_sql = (
"SELECT m.room_id FROM room_memberships as m " "SELECT m.room_id FROM room_memberships as m "
"INNER JOIN current_state_events as c ON m.event_id = c.event_id " "INNER JOIN current_state_events as c ON m.event_id = c.event_id "
"WHERE m.user_id = ?" "WHERE m.user_id = ? AND m.membership = 'join'"
) )
# We also want to get any membership events about that user, e.g. # We also want to get any membership events about that user, e.g.
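This narrowed query means the event stream is only fed from rooms where the user's current membership is 'join', so events stop arriving for rooms they have left. A rough illustration with heavily simplified tables (the real ones carry more columns):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE room_memberships(event_id TEXT, user_id TEXT, room_id TEXT, membership TEXT);
        CREATE TABLE current_state_events(event_id TEXT);
        INSERT INTO room_memberships VALUES ('$1', '@user:test', '!joined:test', 'join');
        INSERT INTO room_memberships VALUES ('$2', '@user:test', '!left:test', 'leave');
        INSERT INTO current_state_events VALUES ('$1');
        INSERT INTO current_state_events VALUES ('$2');
    """)
    sql = (
        "SELECT m.room_id FROM room_memberships as m "
        "INNER JOIN current_state_events as c ON m.event_id = c.event_id "
        "WHERE m.user_id = ? AND m.membership = 'join'"
    )
    print(conn.execute(sql, ("@user:test",)).fetchall())  # [('!joined:test',)]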
@ -188,7 +188,7 @@ class StreamStore(SQLBaseStore):
user_id, user_id, from_id, to_id user_id, user_id, from_id, to_id
) )
ret = [self._parse_event_from_row(r) for r in rows] ret = yield self._parse_events(rows)
if rows: if rows:
key = "s%d" % max([r["stream_ordering"] for r in rows]) key = "s%d" % max([r["stream_ordering"] for r in rows])
@ -243,9 +243,11 @@ class StreamStore(SQLBaseStore):
# TODO (erikj): We should work out what to do here instead. # TODO (erikj): We should work out what to do here instead.
next_token = to_key if to_key else from_key next_token = to_key if to_key else from_key
events = yield self._parse_events(rows)
defer.returnValue( defer.returnValue(
( (
[self._parse_event_from_row(r) for r in rows], events,
next_token next_token
) )
) )
@ -277,15 +279,14 @@ class StreamStore(SQLBaseStore):
else: else:
token = (end_token, end_token) token = (end_token, end_token)
defer.returnValue( events = yield self._parse_events(rows)
(
[self._parse_event_from_row(r) for r in rows], ret = (events, token)
token
) defer.returnValue(ret)
)
def get_room_events_max_id(self): def get_room_events_max_id(self):
return self._db_pool.runInteraction(self._get_room_events_max_id_txn) return self.runInteraction(self._get_room_events_max_id_txn)
def _get_room_events_max_id_txn(self, txn): def _get_room_events_max_id_txn(self, txn):
txn.execute( txn.execute(

@ -41,7 +41,7 @@ class TransactionStore(SQLBaseStore):
this transaction or a 2-tuple of (int, dict) this transaction or a 2-tuple of (int, dict)
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._get_received_txn_response, transaction_id, origin self._get_received_txn_response, transaction_id, origin
) )
@ -72,7 +72,7 @@ class TransactionStore(SQLBaseStore):
response_json (str) response_json (str)
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._set_received_txn_response, self._set_received_txn_response,
transaction_id, origin, code, response_dict transaction_id, origin, code, response_dict
) )
@ -104,7 +104,7 @@ class TransactionStore(SQLBaseStore):
list: A list of previous transaction ids. list: A list of previous transaction ids.
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._prep_send_transaction, self._prep_send_transaction,
transaction_id, destination, ts, pdu_list transaction_id, destination, ts, pdu_list
) )
@ -159,7 +159,7 @@ class TransactionStore(SQLBaseStore):
code (int) code (int)
response_json (str) response_json (str)
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._delivered_txn, self._delivered_txn,
transaction_id, destination, code, response_dict transaction_id, destination, code, response_dict
) )
@ -184,7 +184,7 @@ class TransactionStore(SQLBaseStore):
Returns: Returns:
list: A list of `ReceivedTransactionsTable.EntryType` list: A list of `ReceivedTransactionsTable.EntryType`
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._get_transactions_after, transaction_id, destination self._get_transactions_after, transaction_id, destination
) )
@ -214,7 +214,7 @@ class TransactionStore(SQLBaseStore):
Returns Returns
list: A list of PduTuple list: A list of PduTuple
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._get_pdus_after_transaction, self._get_pdus_after_transaction,
transaction_id, destination transaction_id, destination
) )

@ -0,0 +1,71 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This module allows you to send out emails.
"""
import email.utils
import smtplib
import twisted.python.log
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
import logging
logger = logging.getLogger(__name__)
class EmailException(Exception):
pass
def send_email(smtp_server, from_addr, to_addr, subject, body):
"""Sends an email.
Args:
smtp_server(str): The SMTP server to use.
from_addr(str): The address to send from.
to_addr(str): The address to send to.
subject(str): The subject of the email.
body(str): The plain text body of the email.
Raises:
EmailException if there was a problem sending the mail.
"""
if not smtp_server or not from_addr or not to_addr:
raise EmailException("Need SMTP server, from and to addresses. Check " +
"the config to set these.")
msg = MIMEMultipart('alternative')
msg['Subject'] = subject
msg['From'] = from_addr
msg['To'] = to_addr
plain_part = MIMEText(body)
msg.attach(plain_part)
raw_from = email.utils.parseaddr(from_addr)[1]
raw_to = email.utils.parseaddr(to_addr)[1]
if not raw_from or not raw_to:
raise EmailException("Couldn't parse from/to address.")
logger.info("Sending email to %s on server %s with subject %s",
to_addr, smtp_server, subject)
try:
smtp = smtplib.SMTP(smtp_server)
smtp.sendmail(raw_from, raw_to, msg.as_string())
smtp.quit()
except Exception as origException:
twisted.python.log.err()
ese = EmailException()
ese.cause = origException
raise ese
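A hedged example of calling it (the server and addresses here are placeholders; real values would come from the homeserver config):

    try:
        send_email(
            smtp_server="localhost",
            from_addr="Synapse <noreply@example.com>",
            to_addr="user@example.com",
            subject="Password reset",
            body="Follow this link to reset your password: ...",
        )
    except EmailException as e:
        # .cause is only set when smtplib itself failed
        logger.error("Could not send mail: %r", getattr(e, "cause", e))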

synctl
@ -4,7 +4,6 @@ SYNAPSE="synapse/app/homeserver.py"
CONFIGFILE="homeserver.yaml" CONFIGFILE="homeserver.yaml"
PIDFILE="homeserver.pid" PIDFILE="homeserver.pid"
LOGFILE="homeserver.log"
GREEN=$'\e[1;32m' GREEN=$'\e[1;32m'
NORMAL=$'\e[m' NORMAL=$'\e[m'
@ -14,15 +13,13 @@ set -e
case "$1" in case "$1" in
start) start)
if [ ! -f "$CONFIGFILE" ]; then if [ ! -f "$CONFIGFILE" ]; then
echo "No config file found - generating a default one..." echo "No config file found"
$SYNAPSE -c "$CONFIGFILE" --generate-config echo "To generate a config file, run 'python --generate-config'"
echo "Wrote $CONFIGFILE"
echo "You must now edit this file before continuing"
exit 1 exit 1
fi fi
echo -n "Starting ..." echo -n "Starting ..."
$SYNAPSE --daemonize -c "$CONFIGFILE" --pid-file "$PIDFILE" --log-file "$LOGFILE" $SYNAPSE --daemonize -c "$CONFIGFILE" --pid-file "$PIDFILE"
echo "${GREEN}started${NORMAL}" echo "${GREEN}started${NORMAL}"
;; ;;
stop) stop)

@ -1,6 +1,6 @@
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
import unittest from tests import unittest
class TestRatelimiter(unittest.TestCase): class TestRatelimiter(unittest.TestCase):

@ -15,7 +15,7 @@
from synapse.api.events import SynapseEvent from synapse.api.events import SynapseEvent
import unittest from tests import unittest
class SynapseTemplateCheckTestCase(unittest.TestCase): class SynapseTemplateCheckTestCase(unittest.TestCase):

@ -14,11 +14,10 @@
# trial imports # trial imports
from twisted.internet import defer from twisted.internet import defer
from twisted.trial import unittest from tests import unittest
# python imports # python imports
from mock import Mock from mock import Mock, ANY
import logging
from ..utils import MockHttpResource, MockClock from ..utils import MockHttpResource, MockClock
@ -28,9 +27,6 @@ from synapse.federation.units import Pdu
from synapse.storage.pdu import PduTuple, PduEntry from synapse.storage.pdu import PduTuple, PduEntry
logging.getLogger().addHandler(logging.NullHandler())
def make_pdu(prev_pdus=[], **kwargs): def make_pdu(prev_pdus=[], **kwargs):
"""Provide some default fields for making a PduTuple.""" """Provide some default fields for making a PduTuple."""
pdu_fields = { pdu_fields = {
@ -185,7 +181,8 @@ class FederationTestCase(unittest.TestCase):
"depth": 1, "depth": 1,
}, },
] ]
} },
on_send_callback=ANY,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -216,7 +213,9 @@ class FederationTestCase(unittest.TestCase):
"content": {"testing": "content here"}, "content": {"testing": "content here"},
} }
], ],
}) },
on_send_callback=ANY,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_recv_edu(self): def test_recv_edu(self):

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.trial import unittest from tests import unittest
from synapse.federation.pdu_codec import ( from synapse.federation.pdu_codec import (
PduCodec, encode_event_id, decode_event_id PduCodec, encode_event_id, decode_event_id

@ -14,19 +14,17 @@
# limitations under the License. # limitations under the License.
from twisted.trial import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from mock import Mock from mock import Mock
import logging
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.http.client import HttpClient from synapse.http.client import HttpClient
from synapse.handlers.directory import DirectoryHandler from synapse.handlers.directory import DirectoryHandler
from synapse.storage.directory import RoomAliasMapping from synapse.storage.directory import RoomAliasMapping
from tests.utils import SQLiteMemoryDbPool
logging.getLogger().addHandler(logging.NullHandler())
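This test (and several later ones) swaps its Mock datastore for a real in-memory database via SQLiteMemoryDbPool. The helper itself lives in tests/utils.py and is not shown in this diff; a plausible shape for it, assuming the prepare_database() helper from synapse.storage, is:

    from twisted.enterprise import adbapi
    from synapse.storage import prepare_database

    class SQLiteMemoryDbPool(adbapi.ConnectionPool):
        """Sketch only: a single-connection adbapi pool over an in-memory sqlite db."""

        def __init__(self):
            adbapi.ConnectionPool.__init__(
                self, "sqlite3", ":memory:", cp_min=1, cp_max=1,
            )

        def prepare(self):
            # Create all the tables before the test runs.
            return self.runWithConnection(prepare_database)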
class DirectoryHandlers(object): class DirectoryHandlers(object):
@ -37,6 +35,7 @@ class DirectoryHandlers(object):
class DirectoryTestCase(unittest.TestCase): class DirectoryTestCase(unittest.TestCase):
""" Tests the directory service. """ """ Tests the directory service. """
@defer.inlineCallbacks
def setUp(self): def setUp(self):
self.mock_federation = Mock(spec=[ self.mock_federation = Mock(spec=[
"make_query", "make_query",
@ -47,11 +46,11 @@ class DirectoryTestCase(unittest.TestCase):
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
self.mock_federation.register_query_handler = register_query_handler self.mock_federation.register_query_handler = register_query_handler
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test", hs = HomeServer("test",
datastore=Mock(spec=[ db_pool=db_pool,
"get_association_from_room_alias",
"get_joined_hosts_for_room",
]),
http_client=None, http_client=None,
resource_for_federation=Mock(), resource_for_federation=Mock(),
replication_layer=self.mock_federation, replication_layer=self.mock_federation,
@ -60,20 +59,16 @@ class DirectoryTestCase(unittest.TestCase):
self.handler = hs.get_handlers().directory_handler self.handler = hs.get_handlers().directory_handler
self.datastore = hs.get_datastore() self.store = hs.get_datastore()
def hosts(room_id):
return defer.succeed([])
self.datastore.get_joined_hosts_for_room.side_effect = hosts
self.my_room = hs.parse_roomalias("#my-room:test") self.my_room = hs.parse_roomalias("#my-room:test")
self.your_room = hs.parse_roomalias("#your-room:test")
self.remote_room = hs.parse_roomalias("#another:remote") self.remote_room = hs.parse_roomalias("#another:remote")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_local_association(self): def test_get_local_association(self):
mocked_get = self.datastore.get_association_from_room_alias yield self.store.create_room_alias_association(
mocked_get.return_value = defer.succeed( self.my_room, "!8765qwer:test", ["test"]
RoomAliasMapping("!8765qwer:test", "#my-room:test", ["test"])
) )
result = yield self.handler.get_association(self.my_room) result = yield self.handler.get_association(self.my_room)
@ -106,9 +101,8 @@ class DirectoryTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_incoming_fed_query(self): def test_incoming_fed_query(self):
mocked_get = self.datastore.get_association_from_room_alias yield self.store.create_room_alias_association(
mocked_get.return_value = defer.succeed( self.your_room, "!8765asdf:test", ["test"]
RoomAliasMapping("!8765asdf:test", "#your-room:test", ["test"])
) )
response = yield self.query_handlers["directory"]( response = yield self.query_handlers["directory"](

@ -14,7 +14,7 @@
from twisted.internet import defer from twisted.internet import defer
from twisted.trial import unittest from tests import unittest
from synapse.api.events.room import ( from synapse.api.events.room import (
InviteJoinEvent, MessageEvent, RoomMemberEvent InviteJoinEvent, MessageEvent, RoomMemberEvent
@ -26,12 +26,8 @@ from synapse.federation.units import Pdu
from mock import NonCallableMock, ANY from mock import NonCallableMock, ANY
import logging
from ..utils import get_mock_call_args from ..utils import get_mock_call_args
logging.getLogger().addHandler(logging.NullHandler())
class FederationTestCase(unittest.TestCase): class FederationTestCase(unittest.TestCase):
@ -78,7 +74,9 @@ class FederationTestCase(unittest.TestCase):
yield self.handlers.federation_handler.on_receive_pdu(pdu, False) yield self.handlers.federation_handler.on_receive_pdu(pdu, False)
self.datastore.persist_event.assert_called_once_with(ANY, False) self.datastore.persist_event.assert_called_once_with(
ANY, False, is_new_state=False
)
self.notifier.on_new_room_event.assert_called_once_with(ANY) self.notifier.on_new_room_event.assert_called_once_with(ANY)
@defer.inlineCallbacks @defer.inlineCallbacks

@ -14,14 +14,15 @@
# limitations under the License. # limitations under the License.
from twisted.trial import unittest from tests import unittest
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from mock import Mock, call, ANY from mock import Mock, call, ANY
import logging
import json import json
from ..utils import MockHttpResource, MockClock, DeferredMockCallable from tests.utils import (
MockHttpResource, MockClock, DeferredMockCallable, SQLiteMemoryDbPool
)
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
@ -34,9 +35,6 @@ UNAVAILABLE = PresenceState.UNAVAILABLE
ONLINE = PresenceState.ONLINE ONLINE = PresenceState.ONLINE
logging.getLogger().addHandler(logging.NullHandler())
def _expect_edu(destination, edu_type, content, origin="test"): def _expect_edu(destination, edu_type, content, origin="test"):
return { return {
"origin": origin, "origin": origin,
@ -64,41 +62,36 @@ class JustPresenceHandlers(object):
class PresenceStateTestCase(unittest.TestCase): class PresenceStateTestCase(unittest.TestCase):
""" Tests presence management. """ """ Tests presence management. """
@defer.inlineCallbacks
def setUp(self): def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test", hs = HomeServer("test",
clock=MockClock(), clock=MockClock(),
db_pool=None, db_pool=db_pool,
datastore=Mock(spec=[ handlers=None,
"get_presence_state", resource_for_federation=Mock(),
"set_presence_state", http_client=None,
"add_presence_list_pending", )
"set_presence_list_accepted",
]),
handlers=None,
resource_for_federation=Mock(),
http_client=None,
)
hs.handlers = JustPresenceHandlers(hs) hs.handlers = JustPresenceHandlers(hs)
self.datastore = hs.get_datastore() self.store = hs.get_datastore()
def is_presence_visible(observed_localpart, observer_userid):
allow = (observed_localpart == "apple" and
observer_userid == "@banana:test"
)
return defer.succeed(allow)
self.datastore.is_presence_visible = is_presence_visible
# Mock the RoomMemberHandler # Mock the RoomMemberHandler
room_member_handler = Mock(spec=[]) room_member_handler = Mock(spec=[])
hs.handlers.room_member_handler = room_member_handler hs.handlers.room_member_handler = room_member_handler
logging.getLogger().debug("Mocking room_member_handler=%r", room_member_handler)
# Some local users to test with # Some local users to test with
self.u_apple = hs.parse_userid("@apple:test") self.u_apple = hs.parse_userid("@apple:test")
self.u_banana = hs.parse_userid("@banana:test") self.u_banana = hs.parse_userid("@banana:test")
self.u_clementine = hs.parse_userid("@clementine:test") self.u_clementine = hs.parse_userid("@clementine:test")
yield self.store.create_presence(self.u_apple.localpart)
yield self.store.set_presence_state(
self.u_apple.localpart, {"state": ONLINE, "status_msg": "Online"}
)
self.handler = hs.get_handlers().presence_handler self.handler = hs.get_handlers().presence_handler
self.room_members = [] self.room_members = []
@ -122,7 +115,7 @@ class PresenceStateTestCase(unittest.TestCase):
shared = all(map(lambda i: i in room_member_ids, userlist)) shared = all(map(lambda i: i in room_member_ids, userlist))
return defer.succeed(shared) return defer.succeed(shared)
self.datastore.user_rooms_intersect = user_rooms_intersect self.store.user_rooms_intersect = user_rooms_intersect
self.mock_start = Mock() self.mock_start = Mock()
self.mock_stop = Mock() self.mock_stop = Mock()
@ -132,11 +125,6 @@ class PresenceStateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_my_state(self): def test_get_my_state(self):
mocked_get = self.datastore.get_presence_state
mocked_get.return_value = defer.succeed(
{"state": ONLINE, "status_msg": "Online"}
)
state = yield self.handler.get_state( state = yield self.handler.get_state(
target_user=self.u_apple, auth_user=self.u_apple target_user=self.u_apple, auth_user=self.u_apple
) )
@ -145,13 +133,12 @@ class PresenceStateTestCase(unittest.TestCase):
{"presence": ONLINE, "status_msg": "Online"}, {"presence": ONLINE, "status_msg": "Online"},
state state
) )
mocked_get.assert_called_with("apple")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_allowed_state(self): def test_get_allowed_state(self):
mocked_get = self.datastore.get_presence_state yield self.store.allow_presence_visible(
mocked_get.return_value = defer.succeed( observed_localpart=self.u_apple.localpart,
{"state": ONLINE, "status_msg": "Online"} observer_userid=self.u_banana.to_string(),
) )
state = yield self.handler.get_state( state = yield self.handler.get_state(
@ -162,15 +149,9 @@ class PresenceStateTestCase(unittest.TestCase):
{"presence": ONLINE, "status_msg": "Online"}, {"presence": ONLINE, "status_msg": "Online"},
state state
) )
mocked_get.assert_called_with("apple")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_same_room_state(self): def test_get_same_room_state(self):
mocked_get = self.datastore.get_presence_state
mocked_get.return_value = defer.succeed(
{"state": ONLINE, "status_msg": "Online"}
)
self.room_members = [self.u_apple, self.u_clementine] self.room_members = [self.u_apple, self.u_clementine]
state = yield self.handler.get_state( state = yield self.handler.get_state(
@ -184,11 +165,6 @@ class PresenceStateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_disallowed_state(self): def test_get_disallowed_state(self):
mocked_get = self.datastore.get_presence_state
mocked_get.return_value = defer.succeed(
{"state": ONLINE, "status_msg": "Online"}
)
self.room_members = [] self.room_members = []
yield self.assertFailure( yield self.assertFailure(
@ -200,16 +176,17 @@ class PresenceStateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_state(self): def test_set_my_state(self):
mocked_set = self.datastore.set_presence_state
mocked_set.return_value = defer.succeed({"state": OFFLINE})
yield self.handler.set_state( yield self.handler.set_state(
target_user=self.u_apple, auth_user=self.u_apple, target_user=self.u_apple, auth_user=self.u_apple,
state={"presence": UNAVAILABLE, "status_msg": "Away"}) state={"presence": UNAVAILABLE, "status_msg": "Away"})
mocked_set.assert_called_with("apple", self.assertEquals(
{"state": UNAVAILABLE, "status_msg": "Away"} {"state": UNAVAILABLE,
"status_msg": "Away",
"mtime": 1000000},
(yield self.store.get_presence_state(self.u_apple.localpart))
) )
self.mock_start.assert_called_with(self.u_apple, self.mock_start.assert_called_with(self.u_apple,
state={ state={
"presence": UNAVAILABLE, "presence": UNAVAILABLE,
@ -227,50 +204,34 @@ class PresenceStateTestCase(unittest.TestCase):
class PresenceInvitesTestCase(unittest.TestCase): class PresenceInvitesTestCase(unittest.TestCase):
""" Tests presence management. """ """ Tests presence management. """
@defer.inlineCallbacks
def setUp(self): def setUp(self):
self.mock_http_client = Mock(spec=[]) self.mock_http_client = Mock(spec=[])
self.mock_http_client.put_json = DeferredMockCallable() self.mock_http_client.put_json = DeferredMockCallable()
self.mock_federation_resource = MockHttpResource() self.mock_federation_resource = MockHttpResource()
hs = HomeServer("test", db_pool = SQLiteMemoryDbPool()
clock=MockClock(), yield db_pool.prepare()
db_pool=None,
datastore=Mock(spec=[
"has_presence_state",
"allow_presence_visible",
"add_presence_list_pending",
"set_presence_list_accepted",
"get_presence_list",
"del_presence_list",
# Bits that Federation needs hs = HomeServer("test",
"prep_send_transaction", clock=MockClock(),
"delivered_txn", db_pool=db_pool,
"get_received_txn_response", handlers=None,
"set_received_txn_response", resource_for_client=Mock(),
]), resource_for_federation=self.mock_federation_resource,
handlers=None, http_client=self.mock_http_client,
resource_for_client=Mock(), )
resource_for_federation=self.mock_federation_resource,
http_client=self.mock_http_client,
)
hs.handlers = JustPresenceHandlers(hs) hs.handlers = JustPresenceHandlers(hs)
self.datastore = hs.get_datastore() self.store = hs.get_datastore()
def has_presence_state(user_localpart):
return defer.succeed(
user_localpart in ("apple", "banana"))
self.datastore.has_presence_state = has_presence_state
def get_received_txn_response(*args):
return defer.succeed(None)
self.datastore.get_received_txn_response = get_received_txn_response
# Some local users to test with # Some local users to test with
self.u_apple = hs.parse_userid("@apple:test") self.u_apple = hs.parse_userid("@apple:test")
self.u_banana = hs.parse_userid("@banana:test") self.u_banana = hs.parse_userid("@banana:test")
yield self.store.create_presence(self.u_apple.localpart)
yield self.store.create_presence(self.u_banana.localpart)
# ID of a local user that does not exist # ID of a local user that does not exist
self.u_durian = hs.parse_userid("@durian:test") self.u_durian = hs.parse_userid("@durian:test")
@ -293,12 +254,16 @@ class PresenceInvitesTestCase(unittest.TestCase):
yield self.handler.send_invite( yield self.handler.send_invite(
observer_user=self.u_apple, observed_user=self.u_banana) observer_user=self.u_apple, observed_user=self.u_banana)
self.datastore.add_presence_list_pending.assert_called_with( self.assertEquals(
"apple", "@banana:test") [{"observed_user_id": "@banana:test", "accepted": 1}],
self.datastore.allow_presence_visible.assert_called_with( (yield self.store.get_presence_list(self.u_apple.localpart))
"banana", "@apple:test") )
self.datastore.set_presence_list_accepted.assert_called_with( self.assertTrue(
"apple", "@banana:test") (yield self.store.is_presence_visible(
observed_localpart=self.u_banana.localpart,
observer_userid=self.u_apple.to_string(),
))
)
self.mock_start.assert_called_with( self.mock_start.assert_called_with(
self.u_apple, target_user=self.u_banana) self.u_apple, target_user=self.u_banana)
@ -308,10 +273,10 @@ class PresenceInvitesTestCase(unittest.TestCase):
yield self.handler.send_invite( yield self.handler.send_invite(
observer_user=self.u_apple, observed_user=self.u_durian) observer_user=self.u_apple, observed_user=self.u_durian)
self.datastore.add_presence_list_pending.assert_called_with( self.assertEquals(
"apple", "@durian:test") [],
self.datastore.del_presence_list.assert_called_with( (yield self.store.get_presence_list(self.u_apple.localpart))
"apple", "@durian:test") )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invite_remote(self): def test_invite_remote(self):
@ -324,7 +289,8 @@ class PresenceInvitesTestCase(unittest.TestCase):
"observer_user": "@apple:test", "observer_user": "@apple:test",
"observed_user": "@cabbage:elsewhere", "observed_user": "@cabbage:elsewhere",
} }
) ),
on_send_callback=ANY,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -332,8 +298,10 @@ class PresenceInvitesTestCase(unittest.TestCase):
yield self.handler.send_invite( yield self.handler.send_invite(
observer_user=self.u_apple, observed_user=self.u_cabbage) observer_user=self.u_apple, observed_user=self.u_cabbage)
self.datastore.add_presence_list_pending.assert_called_with( self.assertEquals(
"apple", "@cabbage:elsewhere") [{"observed_user_id": "@cabbage:elsewhere", "accepted": 0}],
(yield self.store.get_presence_list(self.u_apple.localpart))
)
yield put_json.await_calls() yield put_json.await_calls()
@ -350,7 +318,8 @@ class PresenceInvitesTestCase(unittest.TestCase):
"observer_user": "@cabbage:elsewhere", "observer_user": "@cabbage:elsewhere",
"observed_user": "@apple:test", "observed_user": "@apple:test",
} }
) ),
on_send_callback=ANY,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -365,8 +334,12 @@ class PresenceInvitesTestCase(unittest.TestCase):
) )
) )
self.datastore.allow_presence_visible.assert_called_with( self.assertTrue(
"apple", "@cabbage:elsewhere") (yield self.store.is_presence_visible(
observed_localpart=self.u_apple.localpart,
observer_userid=self.u_cabbage.to_string(),
))
)
yield put_json.await_calls() yield put_json.await_calls()
@ -381,7 +354,8 @@ class PresenceInvitesTestCase(unittest.TestCase):
"observer_user": "@cabbage:elsewhere", "observer_user": "@cabbage:elsewhere",
"observed_user": "@durian:test", "observed_user": "@durian:test",
} }
) ),
on_send_callback=ANY,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -400,6 +374,11 @@ class PresenceInvitesTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_accepted_remote(self): def test_accepted_remote(self):
yield self.store.add_presence_list_pending(
observer_localpart=self.u_apple.localpart,
observed_userid=self.u_cabbage.to_string(),
)
yield self.mock_federation_resource.trigger("PUT", yield self.mock_federation_resource.trigger("PUT",
"/_matrix/federation/v1/send/1000000/", "/_matrix/federation/v1/send/1000000/",
_make_edu_json("elsewhere", "m.presence_accept", _make_edu_json("elsewhere", "m.presence_accept",
@ -410,14 +389,21 @@ class PresenceInvitesTestCase(unittest.TestCase):
) )
) )
self.datastore.set_presence_list_accepted.assert_called_with( self.assertEquals(
"apple", "@cabbage:elsewhere") [{"observed_user_id": "@cabbage:elsewhere", "accepted": 1}],
(yield self.store.get_presence_list(self.u_apple.localpart))
)
self.mock_start.assert_called_with( self.mock_start.assert_called_with(
self.u_apple, target_user=self.u_cabbage) self.u_apple, target_user=self.u_cabbage)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_denied_remote(self): def test_denied_remote(self):
yield self.store.add_presence_list_pending(
observer_localpart=self.u_apple.localpart,
observed_userid="@eggplant:elsewhere",
)
yield self.mock_federation_resource.trigger("PUT", yield self.mock_federation_resource.trigger("PUT",
"/_matrix/federation/v1/send/1000000/", "/_matrix/federation/v1/send/1000000/",
_make_edu_json("elsewhere", "m.presence_deny", _make_edu_json("elsewhere", "m.presence_deny",
@ -428,32 +414,65 @@ class PresenceInvitesTestCase(unittest.TestCase):
) )
) )
self.datastore.del_presence_list.assert_called_with( self.assertEquals(
"apple", "@eggplant:elsewhere") [],
(yield self.store.get_presence_list(self.u_apple.localpart))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_drop_local(self): def test_drop_local(self):
yield self.handler.drop( yield self.store.add_presence_list_pending(
observer_user=self.u_apple, observed_user=self.u_banana) observer_localpart=self.u_apple.localpart,
observed_userid=self.u_banana.to_string(),
)
yield self.store.set_presence_list_accepted(
observer_localpart=self.u_apple.localpart,
observed_userid=self.u_banana.to_string(),
)
self.datastore.del_presence_list.assert_called_with( yield self.handler.drop(
"apple", "@banana:test") observer_user=self.u_apple,
observed_user=self.u_banana,
)
self.assertEquals(
[],
(yield self.store.get_presence_list(self.u_apple.localpart))
)
self.mock_stop.assert_called_with( self.mock_stop.assert_called_with(
self.u_apple, target_user=self.u_banana) self.u_apple, target_user=self.u_banana)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_drop_remote(self): def test_drop_remote(self):
yield self.handler.drop( yield self.store.add_presence_list_pending(
observer_user=self.u_apple, observed_user=self.u_cabbage) observer_localpart=self.u_apple.localpart,
observed_userid=self.u_cabbage.to_string(),
)
yield self.store.set_presence_list_accepted(
observer_localpart=self.u_apple.localpart,
observed_userid=self.u_cabbage.to_string(),
)
self.datastore.del_presence_list.assert_called_with( yield self.handler.drop(
"apple", "@cabbage:elsewhere") observer_user=self.u_apple,
observed_user=self.u_cabbage,
)
self.assertEquals(
[],
(yield self.store.get_presence_list(self.u_apple.localpart))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_presence_list(self): def test_get_presence_list(self):
self.datastore.get_presence_list.return_value = defer.succeed( yield self.store.add_presence_list_pending(
[{"observed_user_id": "@banana:test"}] observer_localpart=self.u_apple.localpart,
observed_userid=self.u_banana.to_string(),
)
yield self.store.set_presence_list_accepted(
observer_localpart=self.u_apple.localpart,
observed_userid=self.u_banana.to_string(),
) )
presence = yield self.handler.get_presence_list( presence = yield self.handler.get_presence_list(
@ -461,29 +480,10 @@ class PresenceInvitesTestCase(unittest.TestCase):
self.assertEquals([ self.assertEquals([
{"observed_user": self.u_banana, {"observed_user": self.u_banana,
"presence": OFFLINE}, "presence": OFFLINE,
"accepted": 1},
], presence) ], presence)
self.datastore.get_presence_list.assert_called_with("apple",
accepted=None
)
self.datastore.get_presence_list.return_value = defer.succeed(
[{"observed_user_id": "@banana:test"}]
)
presence = yield self.handler.get_presence_list(
observer_user=self.u_apple, accepted=True
)
self.assertEquals([
{"observed_user": self.u_banana,
"presence": OFFLINE},
], presence)
self.datastore.get_presence_list.assert_called_with("apple",
accepted=True)
class PresencePushTestCase(unittest.TestCase): class PresencePushTestCase(unittest.TestCase):
""" Tests steady-state presence status updates. """ Tests steady-state presence status updates.
@ -770,7 +770,8 @@ class PresencePushTestCase(unittest.TestCase):
"last_active_ago": 0}, "last_active_ago": 0},
], ],
} }
) ),
on_send_callback=ANY,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -785,7 +786,8 @@ class PresencePushTestCase(unittest.TestCase):
"last_active_ago": 0}, "last_active_ago": 0},
], ],
} }
) ),
on_send_callback=ANY,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -911,6 +913,7 @@ class PresencePushTestCase(unittest.TestCase):
], ],
} }
), ),
on_send_callback=ANY,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -925,6 +928,7 @@ class PresencePushTestCase(unittest.TestCase):
], ],
} }
), ),
on_send_callback=ANY,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -954,6 +958,7 @@ class PresencePushTestCase(unittest.TestCase):
], ],
} }
), ),
on_send_callback=ANY,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -1150,6 +1155,7 @@ class PresencePollingTestCase(unittest.TestCase):
"poll": [ "@potato:remote" ], "poll": [ "@potato:remote" ],
}, },
), ),
on_send_callback=ANY,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -1162,6 +1168,7 @@ class PresencePollingTestCase(unittest.TestCase):
"push": [ {"user_id": "@clementine:test" }], "push": [ {"user_id": "@clementine:test" }],
}, },
), ),
on_send_callback=ANY,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -1190,6 +1197,7 @@ class PresencePollingTestCase(unittest.TestCase):
"push": [ {"user_id": "@fig:test" }], "push": [ {"user_id": "@fig:test" }],
}, },
), ),
on_send_callback=ANY,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -1222,6 +1230,7 @@ class PresencePollingTestCase(unittest.TestCase):
"unpoll": [ "@potato:remote" ], "unpoll": [ "@potato:remote" ],
}, },
), ),
on_send_callback=ANY,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -1253,6 +1262,7 @@ class PresencePollingTestCase(unittest.TestCase):
], ],
}, },
), ),
on_send_callback=ANY,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )

@ -16,11 +16,10 @@
"""This file contains tests of the "presence-like" data that is shared between """This file contains tests of the "presence-like" data that is shared between
presence and profiles; namely, the displayname and avatar_url.""" presence and profiles; namely, the displayname and avatar_url."""
from twisted.trial import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from mock import Mock, call, ANY from mock import Mock, call, ANY
import logging
from ..utils import MockClock from ..utils import MockClock
@ -35,9 +34,6 @@ UNAVAILABLE = PresenceState.UNAVAILABLE
ONLINE = PresenceState.ONLINE ONLINE = PresenceState.ONLINE
logging.getLogger().addHandler(logging.NullHandler())
class MockReplication(object): class MockReplication(object):
def __init__(self): def __init__(self):
self.edu_handlers = {} self.edu_handlers = {}
@ -69,6 +65,8 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
"is_presence_visible", "is_presence_visible",
"set_profile_displayname", "set_profile_displayname",
"get_rooms_for_user_where_membership_is",
]), ]),
handlers=None, handlers=None,
resource_for_federation=Mock(), resource_for_federation=Mock(),
@ -136,6 +134,10 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
# Remote user # Remote user
self.u_potato = hs.parse_userid("@potato:remote") self.u_potato = hs.parse_userid("@potato:remote")
self.mock_get_joined = (
self.datastore.get_rooms_for_user_where_membership_is
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_state(self): def test_set_my_state(self):
self.presence_list = [ self.presence_list = [
@ -156,6 +158,11 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_push_local(self): def test_push_local(self):
def get_joined(*args):
return defer.succeed([])
self.mock_get_joined.side_effect = get_joined
self.presence_list = [ self.presence_list = [
{"observed_user_id": "@banana:test"}, {"observed_user_id": "@banana:test"},
{"observed_user_id": "@clementine:test"}, {"observed_user_id": "@clementine:test"},

@@ -14,18 +14,17 @@
# limitations under the License.
-from twisted.trial import unittest
+from tests import unittest
from twisted.internet import defer
from mock import Mock
-import logging
from synapse.api.errors import AuthError
from synapse.server import HomeServer
from synapse.handlers.profile import ProfileHandler
+from synapse.api.constants import Membership
+from tests.utils import SQLiteMemoryDbPool
-logging.getLogger().addHandler(logging.NullHandler())
class ProfileHandlers(object):
@@ -36,6 +35,7 @@ class ProfileHandlers(object):
class ProfileTestCase(unittest.TestCase):
    """ Tests profile management. """
+   @defer.inlineCallbacks
    def setUp(self):
        self.mock_federation = Mock(spec=[
            "make_query",
@@ -46,27 +46,26 @@
self.query_handlers[query_type] = handler
self.mock_federation.register_query_handler = register_query_handler
+db_pool = SQLiteMemoryDbPool()
+yield db_pool.prepare()
hs = HomeServer("test",
-   db_pool=None,
+   db_pool=db_pool,
    http_client=None,
-   datastore=Mock(spec=[
-       "get_profile_displayname",
-       "set_profile_displayname",
-       "get_profile_avatar_url",
-       "set_profile_avatar_url",
-   ]),
    handlers=None,
    resource_for_federation=Mock(),
    replication_layer=self.mock_federation,
)
hs.handlers = ProfileHandlers(hs)
-self.datastore = hs.get_datastore()
+self.store = hs.get_datastore()
self.frank = hs.parse_userid("@1234ABCD:test")
self.bob = hs.parse_userid("@4567:test")
self.alice = hs.parse_userid("@alice:remote")
+yield self.store.create_profile(self.frank.localpart)
self.handler = hs.get_handlers().profile_handler
# TODO(paul): Icky signal declarings.. booo
@@ -74,22 +73,22 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_name(self):
-   mocked_get = self.datastore.get_profile_displayname
-   mocked_get.return_value = defer.succeed("Frank")
+   yield self.store.set_profile_displayname(
+       self.frank.localpart, "Frank"
+   )
    displayname = yield self.handler.get_displayname(self.frank)
    self.assertEquals("Frank", displayname)
-   mocked_get.assert_called_with("1234ABCD")
@defer.inlineCallbacks
def test_set_my_name(self):
-   mocked_set = self.datastore.set_profile_displayname
-   mocked_set.return_value = defer.succeed(())
    yield self.handler.set_displayname(self.frank, self.frank, "Frank Jr.")
-   mocked_set.assert_called_with("1234ABCD", "Frank Jr.")
+   self.assertEquals(
+       (yield self.store.get_profile_displayname(self.frank.localpart)),
+       "Frank Jr."
+   )
@defer.inlineCallbacks
def test_set_my_name_noauth(self):
@@ -114,32 +113,31 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
-   mocked_get = self.datastore.get_profile_displayname
-   mocked_get.return_value = defer.succeed("Caroline")
+   yield self.store.create_profile("caroline")
+   yield self.store.set_profile_displayname("caroline", "Caroline")
    response = yield self.query_handlers["profile"](
        {"user_id": "@caroline:test", "field": "displayname"}
    )
    self.assertEquals({"displayname": "Caroline"}, response)
-   mocked_get.assert_called_with("caroline")
@defer.inlineCallbacks
def test_get_my_avatar(self):
-   mocked_get = self.datastore.get_profile_avatar_url
-   mocked_get.return_value = defer.succeed("http://my.server/me.png")
+   yield self.store.set_profile_avatar_url(
+       self.frank.localpart, "http://my.server/me.png"
+   )
    avatar_url = yield self.handler.get_avatar_url(self.frank)
    self.assertEquals("http://my.server/me.png", avatar_url)
-   mocked_get.assert_called_with("1234ABCD")
@defer.inlineCallbacks
def test_set_my_avatar(self):
-   mocked_set = self.datastore.set_profile_avatar_url
-   mocked_set.return_value = defer.succeed(())
    yield self.handler.set_avatar_url(self.frank, self.frank,
        "http://my.server/pic.gif")
-   mocked_set.assert_called_with("1234ABCD", "http://my.server/pic.gif")
+   self.assertEquals(
+       (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+       "http://my.server/pic.gif"
+   )


@@ -15,7 +15,7 @@
from twisted.internet import defer
-from twisted.trial import unittest
+from tests import unittest
from synapse.api.events.room import (
    InviteJoinEvent, RoomMemberEvent, RoomConfigEvent
@@ -27,10 +27,6 @@ from synapse.server import HomeServer
from mock import Mock, NonCallableMock
-import logging
-logging.getLogger().addHandler(logging.NullHandler())
class RoomMemberHandlerTestCase(unittest.TestCase):

View file

@@ -14,12 +14,11 @@
# limitations under the License.
-from twisted.trial import unittest
+from tests import unittest
from twisted.internet import defer
from mock import Mock, call, ANY
import json
-import logging
from ..utils import MockHttpResource, MockClock, DeferredMockCallable
@@ -27,9 +26,6 @@ from synapse.server import HomeServer
from synapse.handlers.typing import TypingNotificationHandler
-logging.getLogger().addHandler(logging.NullHandler())
def _expect_edu(destination, edu_type, content, origin="test"):
    return {
        "origin": origin,
@@ -173,7 +169,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
        "user_id": self.u_apple.to_string(),
        "typing": True,
    }
-),
+),
+on_send_callback=ANY,
),
defer.succeed((200, "OK"))
)
@@ -223,7 +220,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
        "user_id": self.u_apple.to_string(),
        "typing": False,
    }
-),
+),
+on_send_callback=ANY,
),
defer.succeed((200, "OK"))
)


@@ -14,7 +14,7 @@
# limitations under the License.
""" Tests REST events for /events paths."""
-from twisted.trial import unittest
+from tests import unittest
# twisted imports
from twisted.internet import defer
@@ -27,14 +27,12 @@ from synapse.server import HomeServer
# python imports
import json
-import logging
from ..utils import MockHttpResource, MemoryDataStore
from .utils import RestTestCase
from mock import Mock, NonCallableMock
-logging.getLogger().addHandler(logging.NullHandler())
PATH_PREFIX = "/_matrix/client/api/v1"
@@ -145,6 +143,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
+hs.config.enable_registration_captcha = False
hs.get_handlers().federation_handler = Mock()


@@ -15,11 +15,10 @@
"""Tests REST events for /presence paths."""
-from twisted.trial import unittest
+from tests import unittest
from twisted.internet import defer
from mock import Mock
-import logging
from ..utils import MockHttpResource
@@ -28,9 +27,6 @@ from synapse.handlers.presence import PresenceHandler
from synapse.server import HomeServer
-logging.getLogger().addHandler(logging.NullHandler())
OFFLINE = PresenceState.OFFLINE
UNAVAILABLE = PresenceState.UNAVAILABLE
ONLINE = PresenceState.ONLINE


@@ -15,7 +15,7 @@
"""Tests REST events for /profile paths."""
-from twisted.trial import unittest
+from tests import unittest
from twisted.internet import defer
from mock import Mock
@@ -28,6 +28,7 @@ from synapse.server import HomeServer
myid = "@1234ABCD:test"
PATH_PREFIX = "/_matrix/client/api/v1"
class ProfileTestCase(unittest.TestCase):
    """ Tests profile management. """


@@ -17,7 +17,7 @@
from twisted.internet import defer
# trial imports
-from twisted.trial import unittest
+from tests import unittest
from synapse.api.constants import Membership
@@ -95,8 +95,14 @@ class RestTestCase(unittest.TestCase):
@defer.inlineCallbacks
def register(self, user_id):
-   (code, response) = yield self.mock_resource.trigger("POST", "/register",
-       '{"user_id":"%s"}' % user_id)
+   (code, response) = yield self.mock_resource.trigger(
+       "POST",
+       "/register",
+       json.dumps({
+           "user": user_id,
+           "password": "test",
+           "type": "m.login.password"
+       }))
    self.assertEquals(200, code)
    defer.returnValue(response)


@ -0,0 +1,5 @@
synapse/storage/feedback.py
synapse/storage/keys.py
synapse/storage/pdu.py
synapse/storage/stream.py
synapse/storage/transactions.py


@@ -14,7 +14,7 @@
# limitations under the License.
-from twisted.trial import unittest
+from tests import unittest
from twisted.internet import defer
from mock import Mock, call


@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.storage.directory import DirectoryStore
from tests.utils import SQLiteMemoryDbPool
class DirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test",
db_pool=db_pool,
)
self.store = DirectoryStore(hs)
self.room = hs.parse_roomid("!abcde:test")
self.alias = hs.parse_roomalias("#my-room:test")
@defer.inlineCallbacks
def test_room_to_alias(self):
yield self.store.create_room_alias_association(
room_alias=self.alias,
room_id=self.room.to_string(),
servers=["test"],
)
self.assertEquals(
["#my-room:test"],
(yield self.store.get_aliases_for_room(self.room.to_string()))
)
@defer.inlineCallbacks
def test_alias_to_room(self):
yield self.store.create_room_alias_association(
room_alias=self.alias,
room_id=self.room.to_string(),
servers=["test"],
)
self.assertObjectHasAttributes(
{"room_id": self.room.to_string(),
"servers": ["test"]},
(yield self.store.get_association_from_room_alias(self.alias))
)


@ -0,0 +1,167 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.storage.presence import PresenceStore
from tests.utils import SQLiteMemoryDbPool, MockClock
class PresenceStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test",
clock=MockClock(),
db_pool=db_pool,
)
self.store = PresenceStore(hs)
self.u_apple = hs.parse_userid("@apple:test")
self.u_banana = hs.parse_userid("@banana:test")
@defer.inlineCallbacks
def test_state(self):
yield self.store.create_presence(
self.u_apple.localpart
)
state = yield self.store.get_presence_state(
self.u_apple.localpart
)
self.assertEquals(
{"state": None, "status_msg": None, "mtime": None}, state
)
yield self.store.set_presence_state(
self.u_apple.localpart, {"state": "online", "status_msg": "Here"}
)
state = yield self.store.get_presence_state(
self.u_apple.localpart
)
self.assertEquals(
{"state": "online", "status_msg": "Here", "mtime": 1000000}, state
)
@defer.inlineCallbacks
def test_visibility(self):
self.assertFalse((yield self.store.is_presence_visible(
observed_localpart=self.u_apple.localpart,
observer_userid=self.u_banana.to_string(),
)))
yield self.store.allow_presence_visible(
observed_localpart=self.u_apple.localpart,
observer_userid=self.u_banana.to_string(),
)
self.assertTrue((yield self.store.is_presence_visible(
observed_localpart=self.u_apple.localpart,
observer_userid=self.u_banana.to_string(),
)))
yield self.store.disallow_presence_visible(
observed_localpart=self.u_apple.localpart,
observer_userid=self.u_banana.to_string(),
)
self.assertFalse((yield self.store.is_presence_visible(
observed_localpart=self.u_apple.localpart,
observer_userid=self.u_banana.to_string(),
)))
@defer.inlineCallbacks
def test_presence_list(self):
self.assertEquals(
[],
(yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart,
))
)
self.assertEquals(
[],
(yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart,
accepted=True,
))
)
yield self.store.add_presence_list_pending(
observer_localpart=self.u_apple.localpart,
observed_userid=self.u_banana.to_string(),
)
self.assertEquals(
[{"observed_user_id": "@banana:test", "accepted": 0}],
(yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart,
))
)
self.assertEquals(
[],
(yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart,
accepted=True,
))
)
yield self.store.set_presence_list_accepted(
observer_localpart=self.u_apple.localpart,
observed_userid=self.u_banana.to_string(),
)
self.assertEquals(
[{"observed_user_id": "@banana:test", "accepted": 1}],
(yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart,
))
)
self.assertEquals(
[{"observed_user_id": "@banana:test", "accepted": 1}],
(yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart,
accepted=True,
))
)
yield self.store.del_presence_list(
observer_localpart=self.u_apple.localpart,
observed_userid=self.u_banana.to_string(),
)
self.assertEquals(
[],
(yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart,
))
)
self.assertEquals(
[],
(yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart,
accepted=True,
))
)


@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.storage.profile import ProfileStore
from tests.utils import SQLiteMemoryDbPool
class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test",
db_pool=db_pool,
)
self.store = ProfileStore(hs)
self.u_frank = hs.parse_userid("@frank:test")
@defer.inlineCallbacks
def test_displayname(self):
yield self.store.create_profile(
self.u_frank.localpart
)
yield self.store.set_profile_displayname(
self.u_frank.localpart, "Frank"
)
self.assertEquals(
"Frank",
(yield self.store.get_profile_displayname(self.u_frank.localpart))
)
@defer.inlineCallbacks
def test_avatar_url(self):
yield self.store.create_profile(
self.u_frank.localpart
)
yield self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
)
self.assertEquals(
"http://my.site/here",
(yield self.store.get_profile_avatar_url(self.u_frank.localpart))
)


@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.storage.registration import RegistrationStore
from tests.utils import SQLiteMemoryDbPool
class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test",
db_pool=db_pool,
)
self.store = RegistrationStore(hs)
self.user_id = "@my-user:test"
self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz",
"BcDeFgHiJkLmNoPqRsTuVwXyZa"]
self.pwhash = "{xx1}123456789"
@defer.inlineCallbacks
def test_register(self):
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
self.assertEquals(
# TODO(paul): Surely this field should be 'user_id', not 'name'
# Additionally surely it shouldn't come in a 1-element list
[{"name": self.user_id, "password_hash": self.pwhash}],
(yield self.store.get_user_by_id(self.user_id))
)
self.assertEquals(
self.user_id,
(yield self.store.get_user_by_token(self.tokens[0]))
)
@defer.inlineCallbacks
def test_add_tokens(self):
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
self.assertEquals(
self.user_id,
(yield self.store.get_user_by_token(self.tokens[1]))
)

tests/storage/test_room.py Normal file

@ -0,0 +1,176 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.api.events.room import (
RoomNameEvent, RoomTopicEvent
)
from tests.utils import SQLiteMemoryDbPool
class RoomStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test",
db_pool=db_pool,
)
# We can't test RoomStore on its own without the DirectoryStore, for
# management of the 'room_aliases' table
self.store = hs.get_datastore()
self.room = hs.parse_roomid("!abcde:test")
self.alias = hs.parse_roomalias("#a-room-name:test")
self.u_creator = hs.parse_userid("@creator:test")
yield self.store.store_room(self.room.to_string(),
room_creator_user_id=self.u_creator.to_string(),
is_public=True
)
@defer.inlineCallbacks
def test_get_room(self):
self.assertObjectHasAttributes(
{"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(),
"is_public": True},
(yield self.store.get_room(self.room.to_string()))
)
@defer.inlineCallbacks
def test_store_room_config(self):
yield self.store.store_room_config(self.room.to_string(),
visibility=False
)
self.assertObjectHasAttributes(
{"is_public": False},
(yield self.store.get_room(self.room.to_string()))
)
@defer.inlineCallbacks
def test_get_rooms(self):
# get_rooms does an INNER JOIN on the room_aliases table :(
rooms = yield self.store.get_rooms(is_public=True)
# Should be empty before we add the alias
self.assertEquals([], rooms)
yield self.store.create_room_alias_association(
room_alias=self.alias,
room_id=self.room.to_string(),
servers=["test"]
)
rooms = yield self.store.get_rooms(is_public=True)
self.assertEquals(1, len(rooms))
self.assertEquals({
"name": None,
"room_id": self.room.to_string(),
"topic": None,
"aliases": [self.alias.to_string()],
}, rooms[0])
class RoomEventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test",
db_pool=db_pool,
)
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory()
self.room = hs.parse_roomid("!abcde:test")
yield self.store.store_room(self.room.to_string(),
room_creator_user_id="@creator:text",
is_public=True
)
@defer.inlineCallbacks
def inject_room_event(self, **kwargs):
yield self.store.persist_event(
self.event_factory.create_event(
room_id=self.room.to_string(),
**kwargs
)
)
@defer.inlineCallbacks
def test_room_name(self):
name = u"A-Room-Name"
yield self.inject_room_event(
etype=RoomNameEvent.TYPE,
name=name,
content={"name": name},
depth=1,
)
state = yield self.store.get_current_state(
room_id=self.room.to_string()
)
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(
{"type": "m.room.name",
"room_id": self.room.to_string(),
"name": name},
state[0]
)
@defer.inlineCallbacks
def test_room_topic(self):
topic = u"A place for things"
yield self.inject_room_event(
etype=RoomTopicEvent.TYPE,
topic=topic,
content={"topic": topic},
depth=1,
)
state = yield self.store.get_current_state(
room_id=self.room.to_string()
)
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(
{"type": "m.room.topic",
"room_id": self.room.to_string(),
"topic": topic},
state[0]
)
# Not testing the various 'level' methods for now because there's lots
# of them and need coalescing; see JIRA SPEC-11


@ -0,0 +1,157 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.api.constants import Membership
from synapse.api.events.room import RoomMemberEvent
from tests.utils import SQLiteMemoryDbPool
class RoomMemberStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test",
db_pool=db_pool,
)
# We can't test the RoomMemberStore on its own without the other event
# storage logic
self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory()
self.u_alice = hs.parse_userid("@alice:test")
self.u_bob = hs.parse_userid("@bob:test")
# User elsewhere on another host
self.u_charlie = hs.parse_userid("@charlie:elsewhere")
self.room = hs.parse_roomid("!abc123:test")
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership):
# Have to create a join event using the eventfactory
yield self.store.persist_event(
self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
user_id=user.to_string(),
state_key=user.to_string(),
room_id=room.to_string(),
membership=membership,
content={"membership": membership},
depth=1,
)
)
@defer.inlineCallbacks
def test_one_member(self):
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
self.assertEquals(
Membership.JOIN,
(yield self.store.get_room_member(
user_id=self.u_alice.to_string(),
room_id=self.room.to_string(),
)).membership
)
self.assertEquals(
[self.u_alice.to_string()],
[m.user_id for m in (
yield self.store.get_room_members(self.room.to_string())
)]
)
self.assertEquals(
[self.room.to_string()],
[m.room_id for m in (
yield self.store.get_rooms_for_user_where_membership_is(
self.u_alice.to_string(), [Membership.JOIN]
))
]
)
self.assertFalse(
(yield self.store.user_rooms_intersect(
[self.u_alice.to_string(), self.u_bob.to_string()]
))
)
@defer.inlineCallbacks
def test_two_members(self):
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
self.assertEquals(
{self.u_alice.to_string(), self.u_bob.to_string()},
{m.user_id for m in (
yield self.store.get_room_members(self.room.to_string())
)}
)
self.assertTrue(
(yield self.store.user_rooms_intersect(
[self.u_alice.to_string(), self.u_bob.to_string()]
))
)
@defer.inlineCallbacks
def test_room_hosts(self):
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
self.assertEquals(
["test"],
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
)
# Should still have just one host after second join from it
yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
self.assertEquals(
["test"],
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
)
# Should now have two hosts after join from other host
yield self.inject_room_member(self.room, self.u_charlie, Membership.JOIN)
self.assertEquals(
{"test", "elsewhere"},
set((yield
self.store.get_joined_hosts_for_room(self.room.to_string())
))
)
# Should still have both hosts
yield self.inject_room_member(self.room, self.u_alice, Membership.LEAVE)
self.assertEquals(
{"test", "elsewhere"},
set((yield
self.store.get_joined_hosts_for_room(self.room.to_string())
))
)
# Should have only one host after other leaves
yield self.inject_room_member(self.room, self.u_charlie, Membership.LEAVE)
self.assertEquals(
["test"],
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
)


@ -0,0 +1,226 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.api.constants import Membership
from synapse.api.events.room import RoomMemberEvent, MessageEvent
from tests.utils import SQLiteMemoryDbPool
class StreamStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer(
"test",
db_pool=db_pool,
)
self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory()
self.u_alice = hs.parse_userid("@alice:test")
self.u_bob = hs.parse_userid("@bob:test")
self.room1 = hs.parse_roomid("!abc123:test")
self.room2 = hs.parse_roomid("!xyx987:test")
self.depth = 1
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership, prev_state=None):
self.depth += 1
event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
user_id=user.to_string(),
state_key=user.to_string(),
room_id=room.to_string(),
membership=membership,
content={"membership": membership},
depth=self.depth,
)
if prev_state:
event.prev_state = prev_state
# Have to create a join event using the eventfactory
yield self.store.persist_event(
event
)
defer.returnValue(event)
@defer.inlineCallbacks
def inject_message(self, room, user, body):
self.depth += 1
# Have to create a message event using the eventfactory
yield self.store.persist_event(
self.event_factory.create_event(
etype=MessageEvent.TYPE,
user_id=user.to_string(),
room_id=room.to_string(),
content={"body": body, "msgtype": u"message"},
depth=self.depth,
)
)
@defer.inlineCallbacks
def test_event_stream_get_other(self):
# Both bob and alice join the room
yield self.inject_room_member(
self.room1, self.u_alice, Membership.JOIN
)
yield self.inject_room_member(
self.room1, self.u_bob, Membership.JOIN
)
# Initial stream key:
start = yield self.store.get_room_events_max_id()
yield self.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id()
results, _ = yield self.store.get_room_events_stream(
self.u_bob.to_string(),
start,
end,
None, # Is currently ignored
)
self.assertEqual(1, len(results))
event = results[0]
self.assertObjectHasAttributes(
{
"type": MessageEvent.TYPE,
"user_id": self.u_alice.to_string(),
"content": {"body": "test", "msgtype": "message"},
},
event,
)
@defer.inlineCallbacks
def test_event_stream_get_own(self):
# Both bob and alice join the room
yield self.inject_room_member(
self.room1, self.u_alice, Membership.JOIN
)
yield self.inject_room_member(
self.room1, self.u_bob, Membership.JOIN
)
# Initial stream key:
start = yield self.store.get_room_events_max_id()
yield self.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id()
results, _ = yield self.store.get_room_events_stream(
self.u_alice.to_string(),
start,
end,
None, # Is currently ignored
)
self.assertEqual(1, len(results))
event = results[0]
self.assertObjectHasAttributes(
{
"type": MessageEvent.TYPE,
"user_id": self.u_alice.to_string(),
"content": {"body": "test", "msgtype": "message"},
},
event,
)
@defer.inlineCallbacks
def test_event_stream_join_leave(self):
# Both bob and alice join the room
yield self.inject_room_member(
self.room1, self.u_alice, Membership.JOIN
)
yield self.inject_room_member(
self.room1, self.u_bob, Membership.JOIN
)
# Then bob leaves again.
yield self.inject_room_member(
self.room1, self.u_bob, Membership.LEAVE
)
# Initial stream key:
start = yield self.store.get_room_events_max_id()
yield self.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id()
results, _ = yield self.store.get_room_events_stream(
self.u_bob.to_string(),
start,
end,
None, # Is currently ignored
)
# We should not get the message, as it happened *after* bob left.
self.assertEqual(0, len(results))
@defer.inlineCallbacks
def test_event_stream_prev_content(self):
yield self.inject_room_member(
self.room1, self.u_bob, Membership.JOIN
)
event1 = yield self.inject_room_member(
self.room1, self.u_alice, Membership.JOIN
)
start = yield self.store.get_room_events_max_id()
event2 = yield self.inject_room_member(
self.room1, self.u_alice, Membership.JOIN,
prev_state=event1.event_id,
)
end = yield self.store.get_room_events_max_id()
results, _ = yield self.store.get_room_events_stream(
self.u_bob.to_string(),
start,
end,
None, # Is currently ignored
)
# We should get exactly one event (the second join), and it should carry prev_content.
self.assertEqual(1, len(results))
event = results[0]
self.assertTrue(hasattr(event, "prev_content"), msg="No prev_content key")


@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from tests import unittest
from twisted.internet import defer
-from twisted.trial import unittest
from mock import Mock, patch


@@ -13,23 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from tests import unittest
from twisted.internet import defer
-from twisted.trial import unittest
+from twisted.python.log import PythonLoggingObserver
from synapse.state import StateHandler
from synapse.storage.pdu import PduEntry
from synapse.federation.pdu_codec import encode_event_id
+from synapse.federation.units import Pdu
from collections import namedtuple
from mock import Mock
+import mock
ReturnType = namedtuple(
    "StateReturnType", ["new_branch", "current_branch"]
)
+def _gen_get_power_level(power_level_list):
+    def get_power_level(room_id, user_id):
+        return defer.succeed(power_level_list.get(user_id, None))
+    return get_power_level
class StateTestCase(unittest.TestCase):
    def setUp(self):
        self.persistence = Mock(spec=[
@@ -38,6 +47,7 @@ class StateTestCase(unittest.TestCase):
    "get_latest_pdus_in_context",
    "get_current_state_pdu",
    "get_pdu",
+   "get_power_level",
])
self.replication = Mock(spec=["get_pdu"])
@@ -51,10 +61,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_new_state_key(self):
    # We've never seen anything for this state before
-   new_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
+   new_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u")
+   self.persistence.get_power_level.side_effect = _gen_get_power_level({})
    self.persistence.get_unresolved_state_tree.return_value = (
-       ReturnType([new_pdu], [])
+       (ReturnType([new_pdu], []), None)
    )
    is_new = yield self.state.handle_new_state(new_pdu)
@@ -74,11 +86,44 @@ class StateTestCase(unittest.TestCase):
# We do a direct overwriting of the old state, i.e., the new state
# points to the old state.
-old_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
-new_pdu = new_fake_pdu_entry("B", "test", "mem", "x", "A", 5)
+old_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u1")
+new_pdu = new_fake_pdu("B", "test", "mem", "x", "A", "u2")
+self.persistence.get_power_level.side_effect = _gen_get_power_level({
+    "u1": 10,
+    "u2": 5,
+})
self.persistence.get_unresolved_state_tree.return_value = (
-   ReturnType([new_pdu, old_pdu], [old_pdu])
+   (ReturnType([new_pdu, old_pdu], [old_pdu]), None)
)
+is_new = yield self.state.handle_new_state(new_pdu)
+self.assertTrue(is_new)
+self.persistence.get_unresolved_state_tree.assert_called_once_with(
+    new_pdu
+)
+self.assertEqual(1, self.persistence.update_current_state.call_count)
+self.assertFalse(self.replication.get_pdu.called)
+@defer.inlineCallbacks
+def test_overwrite(self):
+    old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
+    old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", "A", "u2")
+    new_pdu = new_fake_pdu("C", "test", "mem", "x", "B", "u3")
+    self.persistence.get_power_level.side_effect = _gen_get_power_level({
+        "u1": 10,
+        "u2": 5,
+        "u3": 0,
+    })
+    self.persistence.get_unresolved_state_tree.return_value = (
+        (ReturnType([new_pdu, old_pdu_2, old_pdu_1], [old_pdu_1]), None)
+    )
is_new = yield self.state.handle_new_state(new_pdu)
@@ -98,12 +143,18 @@ class StateTestCase(unittest.TestCase):
# We try to update the state based on an outdated state, and have a
# too low power level.
-old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
-old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10)
-new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 5)
+old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
+old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
+new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")
+self.persistence.get_power_level.side_effect = _gen_get_power_level({
+    "u1": 10,
+    "u2": 10,
+    "u3": 5,
+})
self.persistence.get_unresolved_state_tree.return_value = (
-   ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1])
+   (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None)
)
is_new = yield self.state.handle_new_state(new_pdu)
@@ -123,12 +174,18 @@ class StateTestCase(unittest.TestCase):
# We try to update the state based on an outdated state, but have
# sufficient power level to force the update.
-old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
-old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10)
-new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 15)
+old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
+old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
+new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")
+self.persistence.get_power_level.side_effect = _gen_get_power_level({
+    "u1": 10,
+    "u2": 10,
+    "u3": 15,
+})
self.persistence.get_unresolved_state_tree.return_value = (
-   ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1])
+   (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None)
)
is_new = yield self.state.handle_new_state(new_pdu)
@@ -148,12 +205,18 @@ class StateTestCase(unittest.TestCase):
# We try to update the state based on an outdated state, the power
# levels are the same and so are the branch lengths
-old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
-old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10)
-new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 10)
+old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
+old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
+new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")
+self.persistence.get_power_level.side_effect = _gen_get_power_level({
+    "u1": 10,
+    "u2": 10,
+    "u3": 10,
+})
self.persistence.get_unresolved_state_tree.return_value = (
-   ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1])
+   (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None)
)
is_new = yield self.state.handle_new_state(new_pdu)
@@ -173,13 +236,26 @@ class StateTestCase(unittest.TestCase):
# We try to update the state based on an outdated state, the power
# levels are the same but the branch length of the new one is longer.
-old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
-old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10)
-old_pdu_3 = new_fake_pdu_entry("C", "test", "mem", "x", "A", 10)
-new_pdu = new_fake_pdu_entry("D", "test", "mem", "x", "C", 10)
+old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
+old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
+old_pdu_3 = new_fake_pdu("C", "test", "mem", "x", "A", "u3")
+new_pdu = new_fake_pdu("D", "test", "mem", "x", "C", "u4")
+self.persistence.get_power_level.side_effect = _gen_get_power_level({
+    "u1": 10,
+    "u2": 10,
+    "u3": 10,
+    "u4": 10,
+})
self.persistence.get_unresolved_state_tree.return_value = (
-   ReturnType([new_pdu, old_pdu_3, old_pdu_1], [old_pdu_2, old_pdu_1])
+   (
+       ReturnType(
+           [new_pdu, old_pdu_3, old_pdu_1],
+           [old_pdu_2, old_pdu_1]
+       ),
+       None
+   )
)
is_new = yield self.state.handle_new_state(new_pdu)
@@ -200,22 +276,38 @@ class StateTestCase(unittest.TestCase):
# triggering a get_pdu request
# The pdu we haven't seen
-old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
-old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10)
-new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 20)
+old_pdu_1 = new_fake_pdu(
+    "A", "test", "mem", "x", None, "u1", depth=0
+)
+old_pdu_2 = new_fake_pdu(
+    "B", "test", "mem", "x", "A", "u2", depth=1
+)
+new_pdu = new_fake_pdu(
+    "C", "test", "mem", "x", "A", "u3", depth=2
+)
+self.persistence.get_power_level.side_effect = _gen_get_power_level({
+    "u1": 10,
+    "u2": 10,
+    "u3": 20,
+})
# The return_value of `get_unresolved_state_tree`, which changes after
# the call to get_pdu
-tree_to_return = [ReturnType([new_pdu], [old_pdu_2])]
+tree_to_return = [(ReturnType([new_pdu], [old_pdu_2]), 0)]
def return_tree(p):
    return tree_to_return[0]
-def set_return_tree(*args, **kwargs):
-    tree_to_return[0] = ReturnType(
-        [new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]
+def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
+    tree_to_return[0] = (
+        ReturnType(
+            [new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]
+        ),
+        None
    )
+    return defer.succeed(None)
self.persistence.get_unresolved_state_tree.side_effect = return_tree
@@ -227,6 +319,13 @@ class StateTestCase(unittest.TestCase):
self.assertTrue(is_new)
+self.replication.get_pdu.assert_called_with(
+    destination=new_pdu.origin,
+    pdu_origin=old_pdu_1.origin,
+    pdu_id=old_pdu_1.pdu_id,
+    outlier=True
+)
self.persistence.get_unresolved_state_tree.assert_called_with(
    new_pdu
)
@@ -237,11 +336,233 @@ class StateTestCase(unittest.TestCase):
self.assertEqual(1, self.persistence.update_current_state.call_count)
@defer.inlineCallbacks
def test_missing_pdu_depth_1(self):
# We try to update state against a PDU we haven't yet seen,
# triggering a get_pdu request
# The pdu we haven't seen
old_pdu_1 = new_fake_pdu(
"A", "test", "mem", "x", None, "u1", depth=0
)
old_pdu_2 = new_fake_pdu(
"B", "test", "mem", "x", "A", "u2", depth=2
)
old_pdu_3 = new_fake_pdu(
"C", "test", "mem", "x", "B", "u3", depth=3
)
new_pdu = new_fake_pdu(
"D", "test", "mem", "x", "A", "u4", depth=4
)
self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 10,
"u3": 10,
"u4": 20,
})
# The return_value of `get_unresolved_state_tree`, which changes after
# the call to get_pdu
tree_to_return = [
(
ReturnType([new_pdu], [old_pdu_3]),
0
),
(
ReturnType(
[new_pdu, old_pdu_1], [old_pdu_3]
),
1
),
(
ReturnType(
[new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1]
),
None
),
]
to_return = [0]
def return_tree(p):
return tree_to_return[to_return[0]]
def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
to_return[0] += 1
return defer.succeed(None)
self.persistence.get_unresolved_state_tree.side_effect = return_tree
self.replication.get_pdu.side_effect = set_return_tree
self.persistence.get_pdu.return_value = None
is_new = yield self.state.handle_new_state(new_pdu)
self.assertTrue(is_new)
self.assertEqual(2, self.replication.get_pdu.call_count)
self.replication.get_pdu.assert_has_calls(
[
mock.call(
destination=new_pdu.origin,
pdu_origin=old_pdu_1.origin,
pdu_id=old_pdu_1.pdu_id,
outlier=True
),
mock.call(
destination=old_pdu_3.origin,
pdu_origin=old_pdu_2.origin,
pdu_id=old_pdu_2.pdu_id,
outlier=True
),
]
)
self.persistence.get_unresolved_state_tree.assert_called_with(
new_pdu
)
self.assertEquals(
3, self.persistence.get_unresolved_state_tree.call_count
)
self.assertEqual(1, self.persistence.update_current_state.call_count)
@defer.inlineCallbacks
def test_missing_pdu_depth_2(self):
# We try to update state against a PDU we haven't yet seen,
# triggering a get_pdu request
# The pdu we haven't seen
old_pdu_1 = new_fake_pdu(
"A", "test", "mem", "x", None, "u1", depth=0
)
old_pdu_2 = new_fake_pdu(
"B", "test", "mem", "x", "A", "u2", depth=2
)
old_pdu_3 = new_fake_pdu(
"C", "test", "mem", "x", "B", "u3", depth=3
)
new_pdu = new_fake_pdu(
"D", "test", "mem", "x", "A", "u4", depth=1
)
self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 10,
"u3": 10,
"u4": 20,
})
# The return_value of `get_unresolved_state_tree`, which changes after
# the call to get_pdu
tree_to_return = [
(
ReturnType([new_pdu], [old_pdu_3]),
1,
),
(
ReturnType(
[new_pdu], [old_pdu_3, old_pdu_2]
),
0,
),
(
ReturnType(
[new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1]
),
None
),
]
to_return = [0]
def return_tree(p):
return tree_to_return[to_return[0]]
def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
to_return[0] += 1
return defer.succeed(None)
self.persistence.get_unresolved_state_tree.side_effect = return_tree
self.replication.get_pdu.side_effect = set_return_tree
self.persistence.get_pdu.return_value = None
is_new = yield self.state.handle_new_state(new_pdu)
self.assertTrue(is_new)
self.assertEqual(2, self.replication.get_pdu.call_count)
self.replication.get_pdu.assert_has_calls(
[
mock.call(
destination=old_pdu_3.origin,
pdu_origin=old_pdu_2.origin,
pdu_id=old_pdu_2.pdu_id,
outlier=True
),
mock.call(
destination=new_pdu.origin,
pdu_origin=old_pdu_1.origin,
pdu_id=old_pdu_1.pdu_id,
outlier=True
),
]
)
self.persistence.get_unresolved_state_tree.assert_called_with(
new_pdu
)
self.assertEquals(
3, self.persistence.get_unresolved_state_tree.call_count
)
self.assertEqual(1, self.persistence.update_current_state.call_count)
@defer.inlineCallbacks
def test_no_common_ancestor(self):
# We do a direct overwriting of the old state, i.e., the new state
# points to the old state.
old_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u1")
new_pdu = new_fake_pdu("B", "test", "mem", "x", None, "u2")
self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 5,
"u2": 10,
})
self.persistence.get_unresolved_state_tree.return_value = (
(ReturnType([new_pdu], [old_pdu]), None)
)
is_new = yield self.state.handle_new_state(new_pdu)
self.assertTrue(is_new)
self.persistence.get_unresolved_state_tree.assert_called_once_with(
new_pdu
)
self.assertEqual(1, self.persistence.update_current_state.call_count)
self.assertFalse(self.replication.get_pdu.called)
@defer.inlineCallbacks
def test_new_event(self):
    event = Mock()
+   event.event_id = "12123123@test"
-   state_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 20)
+   state_pdu = new_fake_pdu("C", "test", "mem", "x", "A", 20)
    snapshot = Mock()
    snapshot.prev_state_pdu = state_pdu
@@ -268,24 +589,25 @@ class StateTestCase(unittest.TestCase):
)
-def new_fake_pdu_entry(pdu_id, context, pdu_type, state_key, prev_state_id,
-                       power_level):
-    new_pdu = PduEntry(
+def new_fake_pdu(pdu_id, context, pdu_type, state_key, prev_state_id,
+                 user_id, depth=0):
+    new_pdu = Pdu(
        pdu_id=pdu_id,
        pdu_type=pdu_type,
        state_key=state_key,
-       power_level=power_level,
+       user_id=user_id,
        prev_state_id=prev_state_id,
        origin="example.com",
        context="context",
        ts=1405353060021,
-       depth=0,
+       depth=depth,
        content_json="{}",
        unrecognized_keys="{}",
        outlier=True,
        is_state=True,
        prev_state_origin="example.com",
        have_processed=True,
+       content={},
    )
    return new_pdu


@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import unittest
+from tests import unittest
from synapse.server import BaseHomeServer
from synapse.types import UserID, RoomAlias

tests/unittest.py Normal file

@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.trial import unittest
import logging
# logging doesn't have a "don't log anything at all EVARRRR" setting,
# but since the highest value is 50, 1000000 should do ;)
NEVER = 1000000
logging.getLogger().addHandler(logging.StreamHandler())
logging.getLogger().setLevel(NEVER)
def around(target):
"""A CLOS-style 'around' modifier, which wraps the original method of the
given instance with another piece of code.
@around(self)
def method_name(orig, *args, **kwargs):
return orig(*args, **kwargs)
"""
def _around(code):
name = code.__name__
orig = getattr(target, name)
def new(*args, **kwargs):
return code(orig, *args, **kwargs)
setattr(target, name, new)
return _around
class TestCase(unittest.TestCase):
"""A subclass of twisted.trial's TestCase which looks for 'loglevel'
attributes on both itself and its individual test methods, to override the
root logger's logging level while that test (case|method) runs."""
def __init__(self, methodName, *args, **kwargs):
super(TestCase, self).__init__(methodName, *args, **kwargs)
method = getattr(self, methodName)
level = getattr(method, "loglevel",
getattr(self, "loglevel",
NEVER))
@around(self)
def setUp(orig):
old_level = logging.getLogger().level
if old_level != level:
@around(self)
def tearDown(orig):
ret = orig()
logging.getLogger().setLevel(old_level)
return ret
logging.getLogger().setLevel(level)
return orig()
def assertObjectHasAttributes(self, attrs, obj):
"""Asserts that the given object has each of the attributes given, and
that the value of each matches according to assertEquals."""
for (key, value) in attrs.items():
if not hasattr(obj, key):
raise AssertionError("Expected obj to have a '.%s'" % key)
try:
self.assertEquals(attrs[key], getattr(obj, key))
except AssertionError as e:
raise (type(e))(e.message + " for '.%s'" % key)
def DEBUG(target):
"""A decorator to set the .loglevel attribute to logging.DEBUG.
Can apply to either a TestCase or an individual test method."""
target.loglevel = logging.DEBUG
return target


@@ -15,7 +15,7 @@
from twisted.internet import defer
-from twisted.trial import unittest
+from tests import unittest
from synapse.util.lockutils import LockManager
@@ -105,4 +105,4 @@ class LockManagerTestCase(unittest.TestCase):
    pass
with (yield self.lock_manager.lock(key)):
    pass


@@ -16,12 +16,14 @@
from synapse.http.server import HttpServer
from synapse.api.errors import cs_error, CodeMessageException, StoreError
from synapse.api.constants import Membership
+from synapse.storage import prepare_database
from synapse.api.events.room import (
    RoomMemberEvent, MessageEvent
)
from twisted.internet import defer, reactor
+from twisted.enterprise.adbapi import ConnectionPool
from collections import namedtuple
from mock import patch, Mock
@@ -120,6 +122,18 @@ class MockClock(object):
self.now += secs
+class SQLiteMemoryDbPool(ConnectionPool, object):
+    def __init__(self):
+        super(SQLiteMemoryDbPool, self).__init__(
+            "sqlite3", ":memory:",
+            cp_min=1,
+            cp_max=1,
+        )
+    def prepare(self):
+        return self.runWithConnection(prepare_database)
class MemoryDataStore(object):
Room = namedtuple( Room = namedtuple(

webclient/CAPTCHA_SETUP Normal file

@ -0,0 +1,46 @@
Captcha can be enabled for this web client / home server. This file explains how to do that.
The captcha mechanism used is Google's ReCaptcha. This requires API keys from Google.
Getting keys
------------
Requires a public/private key pair from:
https://developers.google.com/recaptcha/
Setting Private ReCaptcha Key
-----------------------------
The private key is an option in the home server config. If it is not
present, you can generate it via --generate-config. Set the following value:
recaptcha_private_key: YOUR_PRIVATE_KEY
In addition, you MUST enable captchas via:
enable_registration_captcha: true
Setting Public ReCaptcha Key
----------------------------
The web client will look for the global variable webClientConfig for config
options. You should put your ReCaptcha public key there like so:
webClientConfig = {
useCaptcha: true,
recaptcha_public_key: "YOUR_PUBLIC_KEY"
}
This should be put in webclient/config.js, which is already .gitignored, rather
than in the web client source files. You MUST set useCaptcha to true, otherwise
no ReCaptcha widget will be generated.
Configuring IP used for auth
----------------------------
The ReCaptcha API requires that the IP address of the user who solved the
captcha is sent. If the client is connecting through a proxy or load balancer,
you may need to use the X-Forwarded-For (XFF) header instead of the origin
IP address. This can be configured as an option on the home server like so:
captcha_ip_origin_is_x_forwarded: true
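Taken together, the home server options mentioned in this file might combine into a
config fragment along these lines (a sketch with a placeholder key; the last option is
only needed when clients connect through a proxy or load balancer):

enable_registration_captcha: true
recaptcha_private_key: YOUR_PRIVATE_KEY
captcha_ip_origin_is_x_forwarded: true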

Some files were not shown because too many files have changed in this diff.