Parse json validation (#16923)

Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
This commit is contained in:
Gordan Trevis 2024-04-18 14:57:38 +02:00 committed by GitHub
parent 09f0957b36
commit 1d47532310
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 220 additions and 47 deletions

1
changelog.d/16923.bugfix Normal file
View file

@ -0,0 +1 @@
Return `400 M_NOT_JSON` upon receiving invalid JSON in query parameters across various client and admin endpoints, rather than an internal server error.

View file

@ -23,6 +23,7 @@
import enum import enum
import logging import logging
import urllib.parse as urlparse
from http import HTTPStatus from http import HTTPStatus
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -450,6 +451,87 @@ def parse_string(
) )
def parse_json(
request: Request,
name: str,
default: Optional[dict] = None,
required: bool = False,
encoding: str = "ascii",
) -> Optional[JsonDict]:
"""
Parse a JSON parameter from the request query string.
Args:
request: the twisted HTTP request.
name: the name of the query parameter.
default: value to use if the parameter is absent,
defaults to None.
required: whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
encoding: The encoding to decode the string content with.
Returns:
A JSON value, or `default` if the named query parameter was not found
and `required` was False.
Raises:
SynapseError if the parameter is absent and required, or if the
parameter is present and not a JSON object.
"""
args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore
return parse_json_from_args(
args,
name,
default,
required=required,
encoding=encoding,
)
def parse_json_from_args(
args: Mapping[bytes, Sequence[bytes]],
name: str,
default: Optional[dict] = None,
required: bool = False,
encoding: str = "ascii",
) -> Optional[JsonDict]:
"""
Parse a JSON parameter from the request query string.
Args:
args: a mapping of request args as bytes to a list of bytes (e.g. request.args).
name: the name of the query parameter.
default: value to use if the parameter is absent,
defaults to None.
required: whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
encoding: the encoding to decode the string content with.
A JSON value, or `default` if the named query parameter was not found
and `required` was False.
Raises:
SynapseError if the parameter is absent and required, or if the
parameter is present and not a JSON object.
"""
name_bytes = name.encode("ascii")
if name_bytes not in args:
if not required:
return default
message = f"Missing required integer query parameter {name}"
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM)
json_str = parse_string_from_args(args, name, required=True, encoding=encoding)
try:
return json_decoder.decode(urlparse.unquote(json_str))
except Exception:
message = f"Query parameter {name} must be a valid JSON object"
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.NOT_JSON)
EnumT = TypeVar("EnumT", bound=enum.Enum) EnumT = TypeVar("EnumT", bound=enum.Enum)

View file

@ -21,7 +21,6 @@
import logging import logging
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple, cast from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from urllib import parse as urlparse
import attr import attr
@ -38,6 +37,7 @@ from synapse.http.servlet import (
assert_params_in_dict, assert_params_in_dict,
parse_enum, parse_enum,
parse_integer, parse_integer,
parse_json,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
@ -51,7 +51,6 @@ from synapse.storage.databases.main.room import RoomSortOrder
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, RoomID, ScheduledTask, UserID, create_requester from synapse.types import JsonDict, RoomID, ScheduledTask, UserID, create_requester
from synapse.types.state import StateFilter from synapse.types.state import StateFilter
from synapse.util import json_decoder
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.api.auth import Auth from synapse.api.auth import Auth
@ -776,14 +775,8 @@ class RoomEventContextServlet(RestServlet):
limit = parse_integer(request, "limit", default=10) limit = parse_integer(request, "limit", default=10)
# picking the API shape for symmetry with /messages # picking the API shape for symmetry with /messages
filter_str = parse_string(request, "filter", encoding="utf-8") filter_json = parse_json(request, "filter", encoding="utf-8")
if filter_str: event_filter = Filter(self._hs, filter_json) if filter_json else None
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
else:
event_filter = None
event_context = await self.room_context_handler.get_event_context( event_context = await self.room_context_handler.get_event_context(
requester, requester,
@ -914,21 +907,16 @@ class RoomMessagesRestServlet(RestServlet):
) )
# Twisted will have processed the args by now. # Twisted will have processed the args by now.
assert request.args is not None assert request.args is not None
filter_json = parse_json(request, "filter", encoding="utf-8")
event_filter = Filter(self._hs, filter_json) if filter_json else None
as_client_event = b"raw" not in request.args as_client_event = b"raw" not in request.args
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
if ( if (
event_filter event_filter
and event_filter.filter_json.get("event_format", "client") and event_filter.filter_json.get("event_format", "client") == "federation"
== "federation"
): ):
as_client_event = False as_client_event = False
else:
event_filter = None
msgs = await self._pagination_handler.get_messages( msgs = await self._pagination_handler.get_messages(
room_id=room_id, room_id=room_id,

View file

@ -52,6 +52,7 @@ from synapse.http.servlet import (
parse_boolean, parse_boolean,
parse_enum, parse_enum,
parse_integer, parse_integer,
parse_json,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
parse_strings_from_args, parse_strings_from_args,
@ -65,7 +66,6 @@ from synapse.rest.client.transactions import HttpTransactionCache
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
from synapse.types.state import StateFilter from synapse.types.state import StateFilter
from synapse.util import json_decoder
from synapse.util.cancellation import cancellable from synapse.util.cancellation import cancellable
from synapse.util.stringutils import parse_and_validate_server_name, random_string from synapse.util.stringutils import parse_and_validate_server_name, random_string
@ -703,21 +703,16 @@ class RoomMessageListRestServlet(RestServlet):
) )
# Twisted will have processed the args by now. # Twisted will have processed the args by now.
assert request.args is not None assert request.args is not None
filter_json = parse_json(request, "filter", encoding="utf-8")
event_filter = Filter(self._hs, filter_json) if filter_json else None
as_client_event = b"raw" not in request.args as_client_event = b"raw" not in request.args
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
if ( if (
event_filter event_filter
and event_filter.filter_json.get("event_format", "client") and event_filter.filter_json.get("event_format", "client") == "federation"
== "federation"
): ):
as_client_event = False as_client_event = False
else:
event_filter = None
msgs = await self.pagination_handler.get_messages( msgs = await self.pagination_handler.get_messages(
room_id=room_id, room_id=room_id,
@ -898,14 +893,8 @@ class RoomEventContextServlet(RestServlet):
limit = parse_integer(request, "limit", default=10) limit = parse_integer(request, "limit", default=10)
# picking the API shape for symmetry with /messages # picking the API shape for symmetry with /messages
filter_str = parse_string(request, "filter", encoding="utf-8") filter_json = parse_json(request, "filter", encoding="utf-8")
if filter_str: event_filter = Filter(self._hs, filter_json) if filter_json else None
filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
else:
event_filter = None
event_context = await self.room_context_handler.get_event_context( event_context = await self.room_context_handler.get_event_context(
requester, room_id, event_id, limit, event_filter requester, room_id, event_id, limit, event_filter

View file

@ -21,6 +21,7 @@
import json import json
import time import time
import urllib.parse import urllib.parse
from http import HTTPStatus
from typing import List, Optional from typing import List, Optional
from unittest.mock import AsyncMock, Mock from unittest.mock import AsyncMock, Mock
@ -2190,6 +2191,33 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
chunk = channel.json_body["chunk"] chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
def test_room_message_filter_query_validation(self) -> None:
# Test json validation in (filter) query parameter.
# Does not test the validity of the filter, only the json validation.
# Check Get with valid json filter parameter, expect 200.
valid_filter_str = '{"types": ["m.room.message"]}'
channel = self.make_request(
"GET",
f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={valid_filter_str}",
access_token=self.admin_user_tok,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# Check Get with invalid json filter parameter, expect 400 NOT_JSON.
invalid_filter_str = "}}}{}"
channel = self.make_request(
"GET",
f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={invalid_filter_str}",
access_token=self.admin_user_tok,
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
self.assertEqual(
channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
)
class JoinAliasRoomTestCase(unittest.HomeserverTestCase): class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
servlets = [ servlets = [
@ -2522,6 +2550,39 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
else: else:
self.fail("Event %s from events_after not found" % j) self.fail("Event %s from events_after not found" % j)
def test_room_event_context_filter_query_validation(self) -> None:
# Test json validation in (filter) query parameter.
# Does not test the validity of the filter, only the json validation.
# Create a user with room and event_id.
user_id = self.register_user("test", "test")
user_tok = self.login("test", "test")
room_id = self.helper.create_room_as(user_id, tok=user_tok)
event_id = self.helper.send(room_id, "message 1", tok=user_tok)["event_id"]
# Check Get with valid json filter parameter, expect 200.
valid_filter_str = '{"types": ["m.room.message"]}'
channel = self.make_request(
"GET",
f"/_synapse/admin/v1/rooms/{room_id}/context/{event_id}?filter={valid_filter_str}",
access_token=self.admin_user_tok,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# Check Get with invalid json filter parameter, expect 400 NOT_JSON.
invalid_filter_str = "}}}{}"
channel = self.make_request(
"GET",
f"/_synapse/admin/v1/rooms/{room_id}/context/{event_id}?filter={invalid_filter_str}",
access_token=self.admin_user_tok,
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
self.assertEqual(
channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
)
class MakeRoomAdminTestCase(unittest.HomeserverTestCase): class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
servlets = [ servlets = [

View file

@ -2175,6 +2175,31 @@ class RoomMessageListTestCase(RoomBase):
chunk = channel.json_body["chunk"] chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
def test_room_message_filter_query_validation(self) -> None:
# Test json validation in (filter) query parameter.
# Does not test the validity of the filter, only the json validation.
# Check Get with valid json filter parameter, expect 200.
valid_filter_str = '{"types": ["m.room.message"]}'
channel = self.make_request(
"GET",
f"/rooms/{self.room_id}/messages?access_token=x&dir=b&filter={valid_filter_str}",
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# Check Get with invalid json filter parameter, expect 400 NOT_JSON.
invalid_filter_str = "}}}{}"
channel = self.make_request(
"GET",
f"/rooms/{self.room_id}/messages?access_token=x&dir=b&filter={invalid_filter_str}",
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
self.assertEqual(
channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
)
class RoomMessageFilterTestCase(RoomBase): class RoomMessageFilterTestCase(RoomBase):
"""Tests /rooms/$room_id/messages REST events.""" """Tests /rooms/$room_id/messages REST events."""
@ -3213,6 +3238,33 @@ class ContextTestCase(unittest.HomeserverTestCase):
self.assertDictEqual(events_after[0].get("content"), {}, events_after[0]) self.assertDictEqual(events_after[0].get("content"), {}, events_after[0])
self.assertEqual(events_after[1].get("content"), {}, events_after[1]) self.assertEqual(events_after[1].get("content"), {}, events_after[1])
def test_room_event_context_filter_query_validation(self) -> None:
# Test json validation in (filter) query parameter.
# Does not test the validity of the filter, only the json validation.
event_id = self.helper.send(self.room_id, "message 7", tok=self.tok)["event_id"]
# Check Get with valid json filter parameter, expect 200.
valid_filter_str = '{"types": ["m.room.message"]}'
channel = self.make_request(
"GET",
f"/rooms/{self.room_id}/context/{event_id}?filter={valid_filter_str}",
access_token=self.tok,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# Check Get with invalid json filter parameter, expect 400 NOT_JSON.
invalid_filter_str = "}}}{}"
channel = self.make_request(
"GET",
f"/rooms/{self.room_id}/context/{event_id}?filter={invalid_filter_str}",
access_token=self.tok,
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
self.assertEqual(
channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
)
class RoomAliasListTestCase(unittest.HomeserverTestCase): class RoomAliasListTestCase(unittest.HomeserverTestCase):
servlets = [ servlets = [