Refactor REST API tests to use explicit reactors (#3351)

This commit is contained in:
Amber Brown 2018-07-17 20:43:18 +10:00 committed by GitHub
parent c7320a5564
commit bc006b3c9d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 985 additions and 980 deletions

0
changelog.d/3351.misc Normal file
View file

View file

@ -42,9 +42,10 @@ class SynapseRequest(Request):
which is handling the request, and returns a context manager. which is handling the request, and returns a context manager.
""" """
def __init__(self, site, *args, **kw): def __init__(self, site, channel, *args, **kw):
Request.__init__(self, *args, **kw) Request.__init__(self, channel, *args, **kw)
self.site = site self.site = site
self._channel = channel
self.authenticated_entity = None self.authenticated_entity = None
self.start_time = 0 self.start_time = 0

View file

@ -643,7 +643,7 @@ class RegisterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_guest_registration(self, params): def _do_guest_registration(self, params):
if not self.hs.config.allow_guest_access: if not self.hs.config.allow_guest_access:
defer.returnValue((403, "Guest access is disabled")) raise SynapseError(403, "Guest access is disabled")
user_id, _ = yield self.registration_handler.register( user_id, _ = yield self.registration_handler.register(
generate_token=False, generate_token=False,
make_guest=True make_guest=True

View file

@ -17,26 +17,22 @@ import json
from mock import Mock from mock import Mock
from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactorClock
from synapse.rest.client.v1.register import CreateUserRestServlet from synapse.http.server import JsonResource
from synapse.rest.client.v1.register import register_servlets
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.utils import mock_getRawHeaders from tests.server import make_request, setup_test_homeserver
class CreateUserServletTestCase(unittest.TestCase): class CreateUserServletTestCase(unittest.TestCase):
"""
Tests for CreateUserRestServlet.
"""
def setUp(self): def setUp(self):
# do the dance to hook up request data to self.request_data
self.request_data = ""
self.request = Mock(
content=Mock(read=Mock(side_effect=lambda: self.request_data)),
path='/_matrix/client/api/v1/createUser'
)
self.request.args = {}
self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.registration_handler = Mock() self.registration_handler = Mock()
self.appservice = Mock(sender="@as:test") self.appservice = Mock(sender="@as:test")
@ -44,39 +40,49 @@ class CreateUserServletTestCase(unittest.TestCase):
get_app_service_by_token=Mock(return_value=self.appservice) get_app_service_by_token=Mock(return_value=self.appservice)
) )
# do the dance to hook things up to the hs global handlers = Mock(registration_handler=self.registration_handler)
handlers = Mock( self.clock = MemoryReactorClock()
registration_handler=self.registration_handler, self.hs_clock = Clock(self.clock)
self.hs = self.hs = setup_test_homeserver(
http_client=None, clock=self.hs_clock, reactor=self.clock
) )
self.hs = Mock()
self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_datastore = Mock(return_value=self.datastore) self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.get_handlers = Mock(return_value=handlers) self.hs.get_handlers = Mock(return_value=handlers)
self.servlet = CreateUserRestServlet(self.hs)
@defer.inlineCallbacks
def test_POST_createuser_with_valid_user(self): def test_POST_createuser_with_valid_user(self):
user_id = "@someone:interesting"
token = "my token" res = JsonResource(self.hs)
self.request.args = { register_servlets(self.hs, res)
"access_token": "i_am_an_app_service"
} request_data = json.dumps(
self.request_data = json.dumps({ {
"localpart": "someone", "localpart": "someone",
"displayname": "someone interesting", "displayname": "someone interesting",
"duration_seconds": 200 "duration_seconds": 200,
}) }
)
url = b'/_matrix/client/api/v1/createUser?access_token=i_am_an_app_service'
user_id = "@someone:interesting"
token = "my token"
self.registration_handler.get_or_create_user = Mock( self.registration_handler.get_or_create_user = Mock(
return_value=(user_id, token) return_value=(user_id, token)
) )
(code, result) = yield self.servlet.on_POST(self.request) request, channel = make_request(b"POST", url, request_data)
self.assertEquals(code, 200) request.render(res)
# Advance the clock because it waits
self.clock.advance(1)
self.assertEquals(channel.result["code"], b"200")
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"home_server": self.hs.hostname "home_server": self.hs.hostname,
} }
self.assertDictContainsSubset(det_data, result) self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))

File diff suppressed because it is too large Load diff

View file

@ -16,13 +16,14 @@
import json import json
import time import time
# twisted imports import attr
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
# trial imports
from tests import unittest from tests import unittest
from tests.server import make_request, wait_until_result
class RestTestCase(unittest.TestCase): class RestTestCase(unittest.TestCase):
@ -133,3 +134,113 @@ class RestTestCase(unittest.TestCase):
for key in required: for key in required:
self.assertEquals(required[key], actual[key], self.assertEquals(required[key], actual[key],
msg="%s mismatch. %s" % (key, actual)) msg="%s mismatch. %s" % (key, actual))
@attr.s
class RestHelper(object):
"""Contains extra helper functions to quickly and clearly perform a given
REST action, which isn't the focus of the test.
"""
hs = attr.ib()
resource = attr.ib()
auth_user_id = attr.ib()
def create_room_as(self, room_creator, is_public=True, tok=None):
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = b"/_matrix/client/r0/createRoom"
content = {}
if not is_public:
content["visibility"] = "private"
if tok:
path = path + b"?access_token=%s" % tok.encode('ascii')
request, channel = make_request(b"POST", path, json.dumps(content).encode('utf8'))
request.render(self.resource)
wait_until_result(self.hs.get_reactor(), channel)
assert channel.result["code"] == b"200", channel.result
self.auth_user_id = temp_id
return channel.json_body["room_id"]
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
self.change_membership(
room=room,
src=src,
targ=targ,
tok=tok,
membership=Membership.INVITE,
expect_code=expect_code,
)
def join(self, room=None, user=None, expect_code=200, tok=None):
self.change_membership(
room=room,
src=user,
targ=user,
tok=tok,
membership=Membership.JOIN,
expect_code=expect_code,
)
def leave(self, room=None, user=None, expect_code=200, tok=None):
self.change_membership(
room=room,
src=user,
targ=user,
tok=tok,
membership=Membership.LEAVE,
expect_code=expect_code,
)
def change_membership(self, room, src, targ, membership, tok=None, expect_code=200):
temp_id = self.auth_user_id
self.auth_user_id = src
path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ)
if tok:
path = path + "?access_token=%s" % tok
data = {"membership": membership}
request, channel = make_request(
b"PUT", path.encode('ascii'), json.dumps(data).encode('utf8')
)
request.render(self.resource)
wait_until_result(self.hs.get_reactor(), channel)
assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
% (expect_code, int(channel.result["code"]), channel.result["body"])
)
self.auth_user_id = temp_id
@defer.inlineCallbacks
def register(self, user_id):
(code, response) = yield self.mock_resource.trigger(
"POST",
"/_matrix/client/r0/register",
json.dumps(
{"user": user_id, "password": "test", "type": "m.login.password"}
),
)
self.assertEquals(200, code)
defer.returnValue(response)
@defer.inlineCallbacks
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
if body is None:
body = "body_text_here"
path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
content = '{"msgtype":"m.text","body":"%s"}' % body
if tok:
path = path + "?access_token=%s" % tok
(code, response) = yield self.mock_resource.trigger("PUT", path, content)
self.assertEquals(expect_code, code, msg=str(response))

View file

@ -1,61 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from twisted.internet import defer
from synapse.types import UserID
from tests import unittest
from ....utils import MockHttpResource, setup_test_homeserver
PATH_PREFIX = "/_matrix/client/v2_alpha"
class V2AlphaRestTestCase(unittest.TestCase):
# Consumer must define
# USER_ID = <some string>
# TO_REGISTER = [<list of REST servlets to register>]
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
hs = yield setup_test_homeserver(
datastore=self.make_datastore_mock(),
http_client=None,
resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource,
)
def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.USER_ID),
"token_id": 1,
"is_guest": False,
}
hs.get_auth().get_user_by_access_token = get_user_by_access_token
for r in self.TO_REGISTER:
r.register_servlets(hs, self.mock_resource)
def make_datastore_mock(self):
store = Mock(spec=[
"insert_client_ip",
])
store.get_app_service_by_token = Mock(return_value=None)
return store

View file

@ -13,35 +13,33 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
import synapse.types import synapse.types
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.http.server import JsonResource
from synapse.rest.client.v2_alpha import filter from synapse.rest.client.v2_alpha import filter
from synapse.types import UserID from synapse.types import UserID
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import ThreadedMemoryReactorClock as MemoryReactorClock
from ....utils import MockHttpResource, setup_test_homeserver from tests.server import make_request, setup_test_homeserver, wait_until_result
PATH_PREFIX = "/_matrix/client/v2_alpha" PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.TestCase): class FilterTestCase(unittest.TestCase):
USER_ID = "@apple:test" USER_ID = b"@apple:test"
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}' EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
TO_REGISTER = [filter] TO_REGISTER = [filter]
@defer.inlineCallbacks
def setUp(self): def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) self.clock = MemoryReactorClock()
self.hs_clock = Clock(self.clock)
self.hs = yield setup_test_homeserver( self.hs = setup_test_homeserver(
http_client=None, http_client=None, clock=self.hs_clock, reactor=self.clock
resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource,
) )
self.auth = self.hs.get_auth() self.auth = self.hs.get_auth()
@ -55,82 +53,103 @@ class FilterTestCase(unittest.TestCase):
def get_user_by_req(request, allow_guest=False, rights="access"): def get_user_by_req(request, allow_guest=False, rights="access"):
return synapse.types.create_requester( return synapse.types.create_requester(
UserID.from_string(self.USER_ID), 1, False, None) UserID.from_string(self.USER_ID), 1, False, None
)
self.auth.get_user_by_access_token = get_user_by_access_token self.auth.get_user_by_access_token = get_user_by_access_token
self.auth.get_user_by_req = get_user_by_req self.auth.get_user_by_req = get_user_by_req
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.filtering = self.hs.get_filtering() self.filtering = self.hs.get_filtering()
self.resource = JsonResource(self.hs)
for r in self.TO_REGISTER: for r in self.TO_REGISTER:
r.register_servlets(self.hs, self.mock_resource) r.register_servlets(self.hs, self.resource)
@defer.inlineCallbacks
def test_add_filter(self): def test_add_filter(self):
(code, response) = yield self.mock_resource.trigger( request, channel = make_request(
"POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON b"POST",
b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
self.EXAMPLE_FILTER_JSON,
) )
self.assertEquals(200, code) request.render(self.resource)
self.assertEquals({"filter_id": "0"}, response) wait_until_result(self.clock, channel)
filter = yield self.store.get_user_filter(
user_localpart='apple', self.assertEqual(channel.result["code"], b"200")
filter_id=0, self.assertEqual(channel.json_body, {"filter_id": "0"})
) filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
self.assertEquals(filter, self.EXAMPLE_FILTER) self.clock.advance(0)
self.assertEquals(filter.result, self.EXAMPLE_FILTER)
@defer.inlineCallbacks
def test_add_filter_for_other_user(self): def test_add_filter_for_other_user(self):
(code, response) = yield self.mock_resource.trigger( request, channel = make_request(
"POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON b"POST",
b"/_matrix/client/r0/user/%s/filter" % (b"@watermelon:test"),
self.EXAMPLE_FILTER_JSON,
) )
self.assertEquals(403, code) request.render(self.resource)
self.assertEquals(response['errcode'], Codes.FORBIDDEN) wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"403")
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
@defer.inlineCallbacks
def test_add_filter_non_local_user(self): def test_add_filter_non_local_user(self):
_is_mine = self.hs.is_mine _is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False self.hs.is_mine = lambda target_user: False
(code, response) = yield self.mock_resource.trigger( request, channel = make_request(
"POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON b"POST",
b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
self.EXAMPLE_FILTER_JSON,
) )
request.render(self.resource)
wait_until_result(self.clock, channel)
self.hs.is_mine = _is_mine self.hs.is_mine = _is_mine
self.assertEquals(403, code) self.assertEqual(channel.result["code"], b"403")
self.assertEquals(response['errcode'], Codes.FORBIDDEN) self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
@defer.inlineCallbacks
def test_get_filter(self): def test_get_filter(self):
filter_id = yield self.filtering.add_user_filter( filter_id = self.filtering.add_user_filter(
user_localpart='apple', user_localpart="apple", user_filter=self.EXAMPLE_FILTER
user_filter=self.EXAMPLE_FILTER
) )
(code, response) = yield self.mock_resource.trigger_get( self.clock.advance(1)
"/user/%s/filter/%s" % (self.USER_ID, filter_id) filter_id = filter_id.result
request, channel = make_request(
b"GET", b"/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
) )
self.assertEquals(200, code) request.render(self.resource)
self.assertEquals(self.EXAMPLE_FILTER, response) wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"200")
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
@defer.inlineCallbacks
def test_get_filter_non_existant(self): def test_get_filter_non_existant(self):
(code, response) = yield self.mock_resource.trigger_get( request, channel = make_request(
"/user/%s/filter/12382148321" % (self.USER_ID) b"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
) )
self.assertEquals(400, code) request.render(self.resource)
self.assertEquals(response['errcode'], Codes.NOT_FOUND) wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"400")
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
# Currently invalid params do not have an appropriate errcode # Currently invalid params do not have an appropriate errcode
# in errors.py # in errors.py
@defer.inlineCallbacks
def test_get_filter_invalid_id(self): def test_get_filter_invalid_id(self):
(code, response) = yield self.mock_resource.trigger_get( request, channel = make_request(
"/user/%s/filter/foobar" % (self.USER_ID) b"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
) )
self.assertEquals(400, code) request.render(self.resource)
wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"400")
# No ID also returns an invalid_id error # No ID also returns an invalid_id error
@defer.inlineCallbacks
def test_get_filter_no_id(self): def test_get_filter_no_id(self):
(code, response) = yield self.mock_resource.trigger_get( request, channel = make_request(
"/user/%s/filter/" % (self.USER_ID) b"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
) )
self.assertEquals(400, code) request.render(self.resource)
wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"400")

View file

@ -2,165 +2,192 @@ import json
from mock import Mock from mock import Mock
from twisted.internet import defer
from twisted.python import failure from twisted.python import failure
from twisted.test.proto_helpers import MemoryReactorClock
from synapse.api.errors import InteractiveAuthIncompleteError, SynapseError from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.http.server import JsonResource
from synapse.rest.client.v2_alpha.register import register_servlets
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.utils import mock_getRawHeaders from tests.server import make_request, setup_test_homeserver, wait_until_result
class RegisterRestServletTestCase(unittest.TestCase): class RegisterRestServletTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
# do the dance to hook up request data to self.request_data
self.request_data = "" self.clock = MemoryReactorClock()
self.request = Mock( self.hs_clock = Clock(self.clock)
content=Mock(read=Mock(side_effect=lambda: self.request_data)), self.url = b"/_matrix/client/r0/register"
path='/_matrix/api/v2_alpha/register'
)
self.request.args = {}
self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.appservice = None self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock( self.auth = Mock(
side_effect=lambda x: self.appservice) get_appservice_by_req=Mock(side_effect=lambda x: self.appservice)
) )
self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None)) self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
self.auth_handler = Mock( self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result), check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
get_session_data=Mock(return_value=None) get_session_data=Mock(return_value=None),
) )
self.registration_handler = Mock() self.registration_handler = Mock()
self.identity_handler = Mock() self.identity_handler = Mock()
self.login_handler = Mock() self.login_handler = Mock()
self.device_handler = Mock() self.device_handler = Mock()
self.device_handler.check_device_registered = Mock(return_value="FAKE")
self.datastore = Mock(return_value=Mock())
self.datastore.get_current_state_deltas = Mock(return_value=[])
# do the dance to hook it up to the hs global # do the dance to hook it up to the hs global
self.handlers = Mock( self.handlers = Mock(
registration_handler=self.registration_handler, registration_handler=self.registration_handler,
identity_handler=self.identity_handler, identity_handler=self.identity_handler,
login_handler=self.login_handler login_handler=self.login_handler,
)
self.hs = setup_test_homeserver(
http_client=None, clock=self.hs_clock, reactor=self.clock
) )
self.hs = Mock()
self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers) self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.get_auth_handler = Mock(return_value=self.auth_handler) self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.get_device_handler = Mock(return_value=self.device_handler) self.hs.get_device_handler = Mock(return_value=self.device_handler)
self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.config.enable_registration = True self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = [] self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = [] self.hs.config.auto_join_rooms = []
# init the thing we're testing self.resource = JsonResource(self.hs)
self.servlet = RegisterRestServlet(self.hs) register_servlets(self.hs, self.resource)
@defer.inlineCallbacks
def test_POST_appservice_registration_valid(self): def test_POST_appservice_registration_valid(self):
user_id = "@kermit:muppet" user_id = "@kermit:muppet"
token = "kermits_access_token" token = "kermits_access_token"
self.request.args = { self.appservice = {"id": "1234"}
"access_token": "i_am_an_app_service" self.registration_handler.appservice_register = Mock(return_value=user_id)
} self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
self.request_data = json.dumps({ request_data = json.dumps({"username": "kermit"})
"username": "kermit"
})
self.appservice = {
"id": "1234"
}
self.registration_handler.appservice_register = Mock(
return_value=user_id
)
self.auth_handler.get_access_token_for_user_id = Mock(
return_value=token
)
(code, result) = yield self.servlet.on_POST(self.request) request, channel = make_request(
self.assertEquals(code, 200) b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
request.render(self.resource)
wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"200", channel.result)
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"home_server": self.hs.hostname "home_server": self.hs.hostname,
} }
self.assertDictContainsSubset(det_data, result) self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
@defer.inlineCallbacks
def test_POST_appservice_registration_invalid(self): def test_POST_appservice_registration_invalid(self):
self.request.args = {
"access_token": "i_am_an_app_service"
}
self.request_data = json.dumps({
"username": "kermit"
})
self.appservice = None # no application service exists self.appservice = None # no application service exists
result = yield self.servlet.on_POST(self.request) request_data = json.dumps({"username": "kermit"})
self.assertEquals(result, (401, None)) request, channel = make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
request.render(self.resource)
wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"401", channel.result)
def test_POST_bad_password(self): def test_POST_bad_password(self):
self.request_data = json.dumps({ request_data = json.dumps({"username": "kermit", "password": 666})
"username": "kermit", request, channel = make_request(b"POST", self.url, request_data)
"password": 666 request.render(self.resource)
}) wait_until_result(self.clock, channel)
d = self.servlet.on_POST(self.request)
return self.assertFailure(d, SynapseError) self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(
json.loads(channel.result["body"])["error"], "Invalid password"
)
def test_POST_bad_username(self): def test_POST_bad_username(self):
self.request_data = json.dumps({ request_data = json.dumps({"username": 777, "password": "monkey"})
"username": 777, request, channel = make_request(b"POST", self.url, request_data)
"password": "monkey" request.render(self.resource)
}) wait_until_result(self.clock, channel)
d = self.servlet.on_POST(self.request)
return self.assertFailure(d, SynapseError) self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(
json.loads(channel.result["body"])["error"], "Invalid username"
)
@defer.inlineCallbacks
def test_POST_user_valid(self): def test_POST_user_valid(self):
user_id = "@kermit:muppet" user_id = "@kermit:muppet"
token = "kermits_access_token" token = "kermits_access_token"
device_id = "frogfone" device_id = "frogfone"
self.request_data = json.dumps({ request_data = json.dumps(
"username": "kermit", {"username": "kermit", "password": "monkey", "device_id": device_id}
"password": "monkey",
"device_id": device_id,
})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (None, {
"username": "kermit",
"password": "monkey"
}, None)
self.registration_handler.register = Mock(return_value=(user_id, None))
self.auth_handler.get_access_token_for_user_id = Mock(
return_value=token
) )
self.device_handler.check_device_registered = \ self.registration_handler.check_username = Mock(return_value=True)
Mock(return_value=device_id) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
self.registration_handler.register = Mock(return_value=(user_id, None))
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
self.device_handler.check_device_registered = Mock(return_value=device_id)
request, channel = make_request(b"POST", self.url, request_data)
request.render(self.resource)
wait_until_result(self.clock, channel)
(code, result) = yield self.servlet.on_POST(self.request)
self.assertEquals(code, 200)
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
self.assertDictContainsSubset(det_data, result) self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
self.auth_handler.get_login_tuple_for_user_id( self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id, initial_device_display_name=None) user_id, device_id=device_id, initial_device_display_name=None
)
def test_POST_disabled_registration(self): def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False self.hs.config.enable_registration = False
self.request_data = json.dumps({ request_data = json.dumps({"username": "kermit", "password": "monkey"})
"username": "kermit",
"password": "monkey"
})
self.registration_handler.check_username = Mock(return_value=True) self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (None, { self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
"username": "kermit",
"password": "monkey"
}, None)
self.registration_handler.register = Mock(return_value=("@user:id", "t")) self.registration_handler.register = Mock(return_value=("@user:id", "t"))
d = self.servlet.on_POST(self.request)
return self.assertFailure(d, SynapseError) request, channel = make_request(b"POST", self.url, request_data)
request.render(self.resource)
wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(
json.loads(channel.result["body"])["error"],
"Registration has been disabled",
)
def test_POST_guest_registration(self):
user_id = "a@b"
self.hs.config.macaroon_secret_key = "test"
self.hs.config.allow_guest_access = True
self.registration_handler.register = Mock(return_value=(user_id, None))
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
request.render(self.resource)
wait_until_result(self.clock, channel)
det_data = {
"user_id": user_id,
"home_server": self.hs.hostname,
"device_id": "guest_device",
}
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
def test_POST_disabled_guest_registration(self):
self.hs.config.allow_guest_access = False
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
request.render(self.resource)
wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(
json.loads(channel.result["body"])["error"], "Guest access is disabled"
)

View file

@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse.types
from synapse.http.server import JsonResource
from synapse.rest.client.v2_alpha import sync
from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
from tests.server import ThreadedMemoryReactorClock as MemoryReactorClock
from tests.server import make_request, setup_test_homeserver, wait_until_result
PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.TestCase):
USER_ID = b"@apple:test"
TO_REGISTER = [sync]
def setUp(self):
self.clock = MemoryReactorClock()
self.hs_clock = Clock(self.clock)
self.hs = setup_test_homeserver(
http_client=None, clock=self.hs_clock, reactor=self.clock
)
self.auth = self.hs.get_auth()
def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.USER_ID),
"token_id": 1,
"is_guest": False,
}
def get_user_by_req(request, allow_guest=False, rights="access"):
return synapse.types.create_requester(
UserID.from_string(self.USER_ID), 1, False, None
)
self.auth.get_user_by_access_token = get_user_by_access_token
self.auth.get_user_by_req = get_user_by_req
self.store = self.hs.get_datastore()
self.filtering = self.hs.get_filtering()
self.resource = JsonResource(self.hs)
for r in self.TO_REGISTER:
r.register_servlets(self.hs, self.resource)
def test_sync_argless(self):
request, channel = make_request(b"GET", b"/_matrix/client/r0/sync")
request.render(self.resource)
wait_until_result(self.clock, channel)
self.assertEqual(channel.result["code"], b"200")
self.assertTrue(
set(
[
"next_batch",
"rooms",
"presence",
"account_data",
"to_device",
"device_lists",
]
).issubset(set(channel.json_body.keys()))
)

View file

@ -80,6 +80,11 @@ def make_request(method, path, content=b""):
content, and return the Request and the Channel underneath. content, and return the Request and the Channel underneath.
""" """
# Decorate it to be the full path
if not path.startswith(b"/_matrix"):
path = b"/_matrix/client/r0/" + path
path = path.replace("//", "/")
if isinstance(content, text_type): if isinstance(content, text_type):
content = content.encode('utf8') content = content.encode('utf8')
@ -110,6 +115,11 @@ def wait_until_result(clock, channel, timeout=100):
clock.advance(0.1) clock.advance(0.1)
def render(request, resource, clock):
request.render(resource)
wait_until_result(clock, request._channel)
class ThreadedMemoryReactorClock(MemoryReactorClock): class ThreadedMemoryReactorClock(MemoryReactorClock):
""" """
A MemoryReactorClock that supports callFromThread. A MemoryReactorClock that supports callFromThread.

View file

@ -33,9 +33,11 @@ class JsonResourceTests(unittest.TestCase):
return (200, kwargs) return (200, kwargs)
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo/(?P<room_id>[^/]*)$")], _callback) res.register_paths(
"GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _callback
)
request, channel = make_request(b"GET", b"/foo/%E2%98%83?a=%E2%98%83") request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83")
request.render(res) request.render(res)
self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]}) self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
@ -51,9 +53,9 @@ class JsonResourceTests(unittest.TestCase):
raise Exception("boo") raise Exception("boo")
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo$")], _callback) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/foo") request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res) request.render(res)
self.assertEqual(channel.result["code"], b'500') self.assertEqual(channel.result["code"], b'500')
@ -74,9 +76,9 @@ class JsonResourceTests(unittest.TestCase):
return d return d
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo$")], _callback) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/foo") request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res) request.render(res)
# No error has been raised yet # No error has been raised yet
@ -96,9 +98,9 @@ class JsonResourceTests(unittest.TestCase):
raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN) raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN)
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo$")], _callback) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/foo") request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res) request.render(res)
self.assertEqual(channel.result["code"], b'403') self.assertEqual(channel.result["code"], b'403')
@ -118,9 +120,9 @@ class JsonResourceTests(unittest.TestCase):
self.fail("shouldn't ever get here") self.fail("shouldn't ever get here")
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo$")], _callback) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/foobar") request, channel = make_request(b"GET", b"/_matrix/foobar")
request.render(res) request.render(res)
self.assertEqual(channel.result["code"], b'400') self.assertEqual(channel.result["code"], b'400')

View file

@ -109,6 +109,17 @@ class TestCase(unittest.TestCase):
except AssertionError as e: except AssertionError as e:
raise (type(e))(e.message + " for '.%s'" % key) raise (type(e))(e.message + " for '.%s'" % key)
def assert_dict(self, required, actual):
"""Does a partial assert of a dict.
Args:
required (dict): The keys and value which MUST be in 'actual'.
actual (dict): The test result. Extra keys will not be checked.
"""
for key in required:
self.assertEquals(required[key], actual[key],
msg="%s mismatch. %s" % (key, actual))
def DEBUG(target): def DEBUG(target):
"""A decorator to set the .loglevel attribute to logging.DEBUG. """A decorator to set the .loglevel attribute to logging.DEBUG.