mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-22 09:35:45 +03:00
Parse json validation (#16923)
Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
This commit is contained in:
parent
09f0957b36
commit
1d47532310
6 changed files with 220 additions and 47 deletions
1
changelog.d/16923.bugfix
Normal file
1
changelog.d/16923.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Return `400 M_NOT_JSON` upon receiving invalid JSON in query parameters across various client and admin endpoints, rather than an internal server error.
|
|
@ -23,6 +23,7 @@
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
|
import urllib.parse as urlparse
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
|
@ -450,6 +451,87 @@ def parse_string(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_json(
|
||||||
|
request: Request,
|
||||||
|
name: str,
|
||||||
|
default: Optional[dict] = None,
|
||||||
|
required: bool = False,
|
||||||
|
encoding: str = "ascii",
|
||||||
|
) -> Optional[JsonDict]:
|
||||||
|
"""
|
||||||
|
Parse a JSON parameter from the request query string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: the twisted HTTP request.
|
||||||
|
name: the name of the query parameter.
|
||||||
|
default: value to use if the parameter is absent,
|
||||||
|
defaults to None.
|
||||||
|
required: whether to raise a 400 SynapseError if the
|
||||||
|
parameter is absent, defaults to False.
|
||||||
|
encoding: The encoding to decode the string content with.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON value, or `default` if the named query parameter was not found
|
||||||
|
and `required` was False.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SynapseError if the parameter is absent and required, or if the
|
||||||
|
parameter is present and not a JSON object.
|
||||||
|
"""
|
||||||
|
args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore
|
||||||
|
return parse_json_from_args(
|
||||||
|
args,
|
||||||
|
name,
|
||||||
|
default,
|
||||||
|
required=required,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_json_from_args(
|
||||||
|
args: Mapping[bytes, Sequence[bytes]],
|
||||||
|
name: str,
|
||||||
|
default: Optional[dict] = None,
|
||||||
|
required: bool = False,
|
||||||
|
encoding: str = "ascii",
|
||||||
|
) -> Optional[JsonDict]:
|
||||||
|
"""
|
||||||
|
Parse a JSON parameter from the request query string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: a mapping of request args as bytes to a list of bytes (e.g. request.args).
|
||||||
|
name: the name of the query parameter.
|
||||||
|
default: value to use if the parameter is absent,
|
||||||
|
defaults to None.
|
||||||
|
required: whether to raise a 400 SynapseError if the
|
||||||
|
parameter is absent, defaults to False.
|
||||||
|
encoding: the encoding to decode the string content with.
|
||||||
|
|
||||||
|
A JSON value, or `default` if the named query parameter was not found
|
||||||
|
and `required` was False.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SynapseError if the parameter is absent and required, or if the
|
||||||
|
parameter is present and not a JSON object.
|
||||||
|
"""
|
||||||
|
name_bytes = name.encode("ascii")
|
||||||
|
|
||||||
|
if name_bytes not in args:
|
||||||
|
if not required:
|
||||||
|
return default
|
||||||
|
|
||||||
|
message = f"Missing required integer query parameter {name}"
|
||||||
|
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM)
|
||||||
|
|
||||||
|
json_str = parse_string_from_args(args, name, required=True, encoding=encoding)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return json_decoder.decode(urlparse.unquote(json_str))
|
||||||
|
except Exception:
|
||||||
|
message = f"Query parameter {name} must be a valid JSON object"
|
||||||
|
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.NOT_JSON)
|
||||||
|
|
||||||
|
|
||||||
EnumT = TypeVar("EnumT", bound=enum.Enum)
|
EnumT = TypeVar("EnumT", bound=enum.Enum)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,6 @@
|
||||||
import logging
|
import logging
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
|
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
|
||||||
from urllib import parse as urlparse
|
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -38,6 +37,7 @@ from synapse.http.servlet import (
|
||||||
assert_params_in_dict,
|
assert_params_in_dict,
|
||||||
parse_enum,
|
parse_enum,
|
||||||
parse_integer,
|
parse_integer,
|
||||||
|
parse_json,
|
||||||
parse_json_object_from_request,
|
parse_json_object_from_request,
|
||||||
parse_string,
|
parse_string,
|
||||||
)
|
)
|
||||||
|
@ -51,7 +51,6 @@ from synapse.storage.databases.main.room import RoomSortOrder
|
||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.types import JsonDict, RoomID, ScheduledTask, UserID, create_requester
|
from synapse.types import JsonDict, RoomID, ScheduledTask, UserID, create_requester
|
||||||
from synapse.types.state import StateFilter
|
from synapse.types.state import StateFilter
|
||||||
from synapse.util import json_decoder
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.api.auth import Auth
|
from synapse.api.auth import Auth
|
||||||
|
@ -776,14 +775,8 @@ class RoomEventContextServlet(RestServlet):
|
||||||
limit = parse_integer(request, "limit", default=10)
|
limit = parse_integer(request, "limit", default=10)
|
||||||
|
|
||||||
# picking the API shape for symmetry with /messages
|
# picking the API shape for symmetry with /messages
|
||||||
filter_str = parse_string(request, "filter", encoding="utf-8")
|
filter_json = parse_json(request, "filter", encoding="utf-8")
|
||||||
if filter_str:
|
event_filter = Filter(self._hs, filter_json) if filter_json else None
|
||||||
filter_json = urlparse.unquote(filter_str)
|
|
||||||
event_filter: Optional[Filter] = Filter(
|
|
||||||
self._hs, json_decoder.decode(filter_json)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
event_filter = None
|
|
||||||
|
|
||||||
event_context = await self.room_context_handler.get_event_context(
|
event_context = await self.room_context_handler.get_event_context(
|
||||||
requester,
|
requester,
|
||||||
|
@ -914,21 +907,16 @@ class RoomMessagesRestServlet(RestServlet):
|
||||||
)
|
)
|
||||||
# Twisted will have processed the args by now.
|
# Twisted will have processed the args by now.
|
||||||
assert request.args is not None
|
assert request.args is not None
|
||||||
|
|
||||||
|
filter_json = parse_json(request, "filter", encoding="utf-8")
|
||||||
|
event_filter = Filter(self._hs, filter_json) if filter_json else None
|
||||||
|
|
||||||
as_client_event = b"raw" not in request.args
|
as_client_event = b"raw" not in request.args
|
||||||
filter_str = parse_string(request, "filter", encoding="utf-8")
|
|
||||||
if filter_str:
|
|
||||||
filter_json = urlparse.unquote(filter_str)
|
|
||||||
event_filter: Optional[Filter] = Filter(
|
|
||||||
self._hs, json_decoder.decode(filter_json)
|
|
||||||
)
|
|
||||||
if (
|
if (
|
||||||
event_filter
|
event_filter
|
||||||
and event_filter.filter_json.get("event_format", "client")
|
and event_filter.filter_json.get("event_format", "client") == "federation"
|
||||||
== "federation"
|
|
||||||
):
|
):
|
||||||
as_client_event = False
|
as_client_event = False
|
||||||
else:
|
|
||||||
event_filter = None
|
|
||||||
|
|
||||||
msgs = await self._pagination_handler.get_messages(
|
msgs = await self._pagination_handler.get_messages(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
|
|
|
@ -52,6 +52,7 @@ from synapse.http.servlet import (
|
||||||
parse_boolean,
|
parse_boolean,
|
||||||
parse_enum,
|
parse_enum,
|
||||||
parse_integer,
|
parse_integer,
|
||||||
|
parse_json,
|
||||||
parse_json_object_from_request,
|
parse_json_object_from_request,
|
||||||
parse_string,
|
parse_string,
|
||||||
parse_strings_from_args,
|
parse_strings_from_args,
|
||||||
|
@ -65,7 +66,6 @@ from synapse.rest.client.transactions import HttpTransactionCache
|
||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
|
from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
|
||||||
from synapse.types.state import StateFilter
|
from synapse.types.state import StateFilter
|
||||||
from synapse.util import json_decoder
|
|
||||||
from synapse.util.cancellation import cancellable
|
from synapse.util.cancellation import cancellable
|
||||||
from synapse.util.stringutils import parse_and_validate_server_name, random_string
|
from synapse.util.stringutils import parse_and_validate_server_name, random_string
|
||||||
|
|
||||||
|
@ -703,21 +703,16 @@ class RoomMessageListRestServlet(RestServlet):
|
||||||
)
|
)
|
||||||
# Twisted will have processed the args by now.
|
# Twisted will have processed the args by now.
|
||||||
assert request.args is not None
|
assert request.args is not None
|
||||||
|
|
||||||
|
filter_json = parse_json(request, "filter", encoding="utf-8")
|
||||||
|
event_filter = Filter(self._hs, filter_json) if filter_json else None
|
||||||
|
|
||||||
as_client_event = b"raw" not in request.args
|
as_client_event = b"raw" not in request.args
|
||||||
filter_str = parse_string(request, "filter", encoding="utf-8")
|
|
||||||
if filter_str:
|
|
||||||
filter_json = urlparse.unquote(filter_str)
|
|
||||||
event_filter: Optional[Filter] = Filter(
|
|
||||||
self._hs, json_decoder.decode(filter_json)
|
|
||||||
)
|
|
||||||
if (
|
if (
|
||||||
event_filter
|
event_filter
|
||||||
and event_filter.filter_json.get("event_format", "client")
|
and event_filter.filter_json.get("event_format", "client") == "federation"
|
||||||
== "federation"
|
|
||||||
):
|
):
|
||||||
as_client_event = False
|
as_client_event = False
|
||||||
else:
|
|
||||||
event_filter = None
|
|
||||||
|
|
||||||
msgs = await self.pagination_handler.get_messages(
|
msgs = await self.pagination_handler.get_messages(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
|
@ -898,14 +893,8 @@ class RoomEventContextServlet(RestServlet):
|
||||||
limit = parse_integer(request, "limit", default=10)
|
limit = parse_integer(request, "limit", default=10)
|
||||||
|
|
||||||
# picking the API shape for symmetry with /messages
|
# picking the API shape for symmetry with /messages
|
||||||
filter_str = parse_string(request, "filter", encoding="utf-8")
|
filter_json = parse_json(request, "filter", encoding="utf-8")
|
||||||
if filter_str:
|
event_filter = Filter(self._hs, filter_json) if filter_json else None
|
||||||
filter_json = urlparse.unquote(filter_str)
|
|
||||||
event_filter: Optional[Filter] = Filter(
|
|
||||||
self._hs, json_decoder.decode(filter_json)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
event_filter = None
|
|
||||||
|
|
||||||
event_context = await self.room_context_handler.get_event_context(
|
event_context = await self.room_context_handler.get_event_context(
|
||||||
requester, room_id, event_id, limit, event_filter
|
requester, room_id, event_id, limit, event_filter
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
from http import HTTPStatus
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from unittest.mock import AsyncMock, Mock
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
|
@ -2190,6 +2191,33 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
|
||||||
chunk = channel.json_body["chunk"]
|
chunk = channel.json_body["chunk"]
|
||||||
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
|
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
|
||||||
|
|
||||||
|
def test_room_message_filter_query_validation(self) -> None:
|
||||||
|
# Test json validation in (filter) query parameter.
|
||||||
|
# Does not test the validity of the filter, only the json validation.
|
||||||
|
|
||||||
|
# Check Get with valid json filter parameter, expect 200.
|
||||||
|
valid_filter_str = '{"types": ["m.room.message"]}'
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={valid_filter_str}",
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||||
|
|
||||||
|
# Check Get with invalid json filter parameter, expect 400 NOT_JSON.
|
||||||
|
invalid_filter_str = "}}}{}"
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={invalid_filter_str}",
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
|
||||||
|
self.assertEqual(
|
||||||
|
channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
|
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
|
||||||
servlets = [
|
servlets = [
|
||||||
|
@ -2522,6 +2550,39 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
|
||||||
else:
|
else:
|
||||||
self.fail("Event %s from events_after not found" % j)
|
self.fail("Event %s from events_after not found" % j)
|
||||||
|
|
||||||
|
def test_room_event_context_filter_query_validation(self) -> None:
|
||||||
|
# Test json validation in (filter) query parameter.
|
||||||
|
# Does not test the validity of the filter, only the json validation.
|
||||||
|
|
||||||
|
# Create a user with room and event_id.
|
||||||
|
user_id = self.register_user("test", "test")
|
||||||
|
user_tok = self.login("test", "test")
|
||||||
|
room_id = self.helper.create_room_as(user_id, tok=user_tok)
|
||||||
|
event_id = self.helper.send(room_id, "message 1", tok=user_tok)["event_id"]
|
||||||
|
|
||||||
|
# Check Get with valid json filter parameter, expect 200.
|
||||||
|
valid_filter_str = '{"types": ["m.room.message"]}'
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/_synapse/admin/v1/rooms/{room_id}/context/{event_id}?filter={valid_filter_str}",
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||||
|
|
||||||
|
# Check Get with invalid json filter parameter, expect 400 NOT_JSON.
|
||||||
|
invalid_filter_str = "}}}{}"
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/_synapse/admin/v1/rooms/{room_id}/context/{event_id}?filter={invalid_filter_str}",
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
|
||||||
|
self.assertEqual(
|
||||||
|
channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
|
class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
|
||||||
servlets = [
|
servlets = [
|
||||||
|
|
|
@ -2175,6 +2175,31 @@ class RoomMessageListTestCase(RoomBase):
|
||||||
chunk = channel.json_body["chunk"]
|
chunk = channel.json_body["chunk"]
|
||||||
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
|
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
|
||||||
|
|
||||||
|
def test_room_message_filter_query_validation(self) -> None:
|
||||||
|
# Test json validation in (filter) query parameter.
|
||||||
|
# Does not test the validity of the filter, only the json validation.
|
||||||
|
|
||||||
|
# Check Get with valid json filter parameter, expect 200.
|
||||||
|
valid_filter_str = '{"types": ["m.room.message"]}'
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/rooms/{self.room_id}/messages?access_token=x&dir=b&filter={valid_filter_str}",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||||
|
|
||||||
|
# Check Get with invalid json filter parameter, expect 400 NOT_JSON.
|
||||||
|
invalid_filter_str = "}}}{}"
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/rooms/{self.room_id}/messages?access_token=x&dir=b&filter={invalid_filter_str}",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
|
||||||
|
self.assertEqual(
|
||||||
|
channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RoomMessageFilterTestCase(RoomBase):
|
class RoomMessageFilterTestCase(RoomBase):
|
||||||
"""Tests /rooms/$room_id/messages REST events."""
|
"""Tests /rooms/$room_id/messages REST events."""
|
||||||
|
@ -3213,6 +3238,33 @@ class ContextTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertDictEqual(events_after[0].get("content"), {}, events_after[0])
|
self.assertDictEqual(events_after[0].get("content"), {}, events_after[0])
|
||||||
self.assertEqual(events_after[1].get("content"), {}, events_after[1])
|
self.assertEqual(events_after[1].get("content"), {}, events_after[1])
|
||||||
|
|
||||||
|
def test_room_event_context_filter_query_validation(self) -> None:
|
||||||
|
# Test json validation in (filter) query parameter.
|
||||||
|
# Does not test the validity of the filter, only the json validation.
|
||||||
|
event_id = self.helper.send(self.room_id, "message 7", tok=self.tok)["event_id"]
|
||||||
|
|
||||||
|
# Check Get with valid json filter parameter, expect 200.
|
||||||
|
valid_filter_str = '{"types": ["m.room.message"]}'
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/rooms/{self.room_id}/context/{event_id}?filter={valid_filter_str}",
|
||||||
|
access_token=self.tok,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||||
|
|
||||||
|
# Check Get with invalid json filter parameter, expect 400 NOT_JSON.
|
||||||
|
invalid_filter_str = "}}}{}"
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/rooms/{self.room_id}/context/{event_id}?filter={invalid_filter_str}",
|
||||||
|
access_token=self.tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
|
||||||
|
self.assertEqual(
|
||||||
|
channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RoomAliasListTestCase(unittest.HomeserverTestCase):
|
class RoomAliasListTestCase(unittest.HomeserverTestCase):
|
||||||
servlets = [
|
servlets = [
|
||||||
|
|
Loading…
Reference in a new issue