mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-28 07:00:51 +03:00
Fix destination_is
errors seen in sentry. (#13041)
* Rename test_fedclient to match its source file * Require at least one destination to be truthy * Explicitly validate user ID in profile endpoint GETs Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
parent
aef398457f
commit
c99b511db9
7 changed files with 59 additions and 8 deletions
2
changelog.d/13041.bugfix
Normal file
2
changelog.d/13041.bugfix
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
Fix a bug introduced in Synapse 1.58 where profile requests for a malformed user ID would ccause an internal error. Synapse now returns 400 Bad Request in this situation.
|
||||||
|
|
|
@ -731,8 +731,11 @@ class MatrixFederationHttpClient:
|
||||||
Returns:
|
Returns:
|
||||||
A list of headers to be added as "Authorization:" headers
|
A list of headers to be added as "Authorization:" headers
|
||||||
"""
|
"""
|
||||||
if destination is None and destination_is is None:
|
if not destination and not destination_is:
|
||||||
raise ValueError("destination and destination_is cannot both be None!")
|
raise ValueError(
|
||||||
|
"At least one of the arguments destination and destination_is "
|
||||||
|
"must be a nonempty bytestring."
|
||||||
|
)
|
||||||
|
|
||||||
request: JsonDict = {
|
request: JsonDict = {
|
||||||
"method": method.decode("ascii"),
|
"method": method.decode("ascii"),
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
""" This module contains REST servlets to do with profile: /profile/<paths> """
|
""" This module contains REST servlets to do with profile: /profile/<paths> """
|
||||||
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, Tuple
|
from typing import TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
|
@ -45,8 +45,12 @@ class ProfileDisplaynameRestServlet(RestServlet):
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user = requester.user
|
requester_user = requester.user
|
||||||
|
|
||||||
user = UserID.from_string(user_id)
|
if not UserID.is_valid(user_id):
|
||||||
|
raise SynapseError(
|
||||||
|
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
|
||||||
|
)
|
||||||
|
|
||||||
|
user = UserID.from_string(user_id)
|
||||||
await self.profile_handler.check_profile_query_allowed(user, requester_user)
|
await self.profile_handler.check_profile_query_allowed(user, requester_user)
|
||||||
|
|
||||||
displayname = await self.profile_handler.get_displayname(user)
|
displayname = await self.profile_handler.get_displayname(user)
|
||||||
|
@ -98,8 +102,12 @@ class ProfileAvatarURLRestServlet(RestServlet):
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user = requester.user
|
requester_user = requester.user
|
||||||
|
|
||||||
user = UserID.from_string(user_id)
|
if not UserID.is_valid(user_id):
|
||||||
|
raise SynapseError(
|
||||||
|
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
|
||||||
|
)
|
||||||
|
|
||||||
|
user = UserID.from_string(user_id)
|
||||||
await self.profile_handler.check_profile_query_allowed(user, requester_user)
|
await self.profile_handler.check_profile_query_allowed(user, requester_user)
|
||||||
|
|
||||||
avatar_url = await self.profile_handler.get_avatar_url(user)
|
avatar_url = await self.profile_handler.get_avatar_url(user)
|
||||||
|
@ -150,8 +158,12 @@ class ProfileRestServlet(RestServlet):
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user = requester.user
|
requester_user = requester.user
|
||||||
|
|
||||||
user = UserID.from_string(user_id)
|
if not UserID.is_valid(user_id):
|
||||||
|
raise SynapseError(
|
||||||
|
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
|
||||||
|
)
|
||||||
|
|
||||||
|
user = UserID.from_string(user_id)
|
||||||
await self.profile_handler.check_profile_query_allowed(user, requester_user)
|
await self.profile_handler.check_profile_query_allowed(user, requester_user)
|
||||||
|
|
||||||
displayname = await self.profile_handler.get_displayname(user)
|
displayname = await self.profile_handler.get_displayname(user)
|
||||||
|
|
|
@ -267,7 +267,6 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
|
||||||
)
|
)
|
||||||
|
|
||||||
domain = parts[1]
|
domain = parts[1]
|
||||||
|
|
||||||
# This code will need changing if we want to support multiple domain
|
# This code will need changing if we want to support multiple domain
|
||||||
# names on one HS
|
# names on one HS
|
||||||
return cls(localpart=parts[0], domain=domain)
|
return cls(localpart=parts[0], domain=domain)
|
||||||
|
@ -279,6 +278,8 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_valid(cls: Type[DS], s: str) -> bool:
|
def is_valid(cls: Type[DS], s: str) -> bool:
|
||||||
"""Parses the input string and attempts to ensure it is valid."""
|
"""Parses the input string and attempts to ensure it is valid."""
|
||||||
|
# TODO: this does not reject an empty localpart or an overly-long string.
|
||||||
|
# See https://spec.matrix.org/v1.2/appendices/#identifier-grammar
|
||||||
try:
|
try:
|
||||||
obj = cls.from_string(s)
|
obj = cls.from_string(s)
|
||||||
# Apply additional validation to the domain. This is only done
|
# Apply additional validation to the domain. This is only done
|
||||||
|
|
|
@ -617,3 +617,17 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
self.assertIsInstance(f.value, RequestSendFailed)
|
self.assertIsInstance(f.value, RequestSendFailed)
|
||||||
|
|
||||||
self.assertTrue(transport.disconnecting)
|
self.assertTrue(transport.disconnecting)
|
||||||
|
|
||||||
|
def test_build_auth_headers_rejects_falsey_destinations(self) -> None:
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.cl.build_auth_headers(None, b"GET", b"https://example.com")
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.cl.build_auth_headers(b"", b"GET", b"https://example.com")
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.cl.build_auth_headers(
|
||||||
|
None, b"GET", b"https://example.com", destination_is=b""
|
||||||
|
)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.cl.build_auth_headers(
|
||||||
|
b"", b"GET", b"https://example.com", destination_is=b""
|
||||||
|
)
|
|
@ -13,6 +13,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Tests REST events for /profile paths."""
|
"""Tests REST events for /profile paths."""
|
||||||
|
import urllib.parse
|
||||||
|
from http import HTTPStatus
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
@ -49,6 +51,12 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||||
res = self._get_displayname()
|
res = self._get_displayname()
|
||||||
self.assertEqual(res, "owner")
|
self.assertEqual(res, "owner")
|
||||||
|
|
||||||
|
def test_get_displayname_rejects_bad_username(self) -> None:
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET", f"/profile/{urllib.parse.quote('@alice:')}/displayname"
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||||
|
|
||||||
def test_set_displayname(self) -> None:
|
def test_set_displayname(self) -> None:
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
|
|
|
@ -26,10 +26,21 @@ class UserIDTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual("test", user.domain)
|
self.assertEqual("test", user.domain)
|
||||||
self.assertEqual(True, self.hs.is_mine(user))
|
self.assertEqual(True, self.hs.is_mine(user))
|
||||||
|
|
||||||
def test_pase_empty(self):
|
def test_parse_rejects_empty_id(self):
|
||||||
with self.assertRaises(SynapseError):
|
with self.assertRaises(SynapseError):
|
||||||
UserID.from_string("")
|
UserID.from_string("")
|
||||||
|
|
||||||
|
def test_parse_rejects_missing_sigil(self):
|
||||||
|
with self.assertRaises(SynapseError):
|
||||||
|
UserID.from_string("alice:example.com")
|
||||||
|
|
||||||
|
def test_parse_rejects_missing_separator(self):
|
||||||
|
with self.assertRaises(SynapseError):
|
||||||
|
UserID.from_string("@alice.example.com")
|
||||||
|
|
||||||
|
def test_validation_rejects_missing_domain(self):
|
||||||
|
self.assertFalse(UserID.is_valid("@alice:"))
|
||||||
|
|
||||||
def test_build(self):
|
def test_build(self):
|
||||||
user = UserID("5678efgh", "my.domain")
|
user = UserID("5678efgh", "my.domain")
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue