Allow selecting "prejoin" events by state keys (#14642)

* Declare new config

* Parse new config

* Read new config

* Don't use trial/our TestCase where it's not needed

Before:

```
$ time trial tests/events/test_utils.py > /dev/null

real	0m2.277s
user	0m2.186s
sys	0m0.083s
```

After:
```
$ time trial tests/events/test_utils.py > /dev/null

real	0m0.566s
user	0m0.508s
sys	0m0.056s
```

* Helper to upsert to event fields

without exceeding size limits.

* Use helper when adding invite/knock state

Now that we allow admins to include events in prejoin room state with
arbitrary state keys, be a good Matrix citizen and ensure they don't
accidentally create an oversized event.

* Changelog

* Move StateFilter tests

should have done this in #14668

* Add extra methods to StateFilter

* Use StateFilter

* Ensure test file enforces typed defs; alphabetise

* Workaround surprising get_current_state_ids

* Whoops, fix mypy
This commit is contained in:
David Robertson 2022-12-13 00:54:46 +00:00 committed by GitHub
parent 3d87847ecc
commit e2a1adbf5d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 982 additions and 694 deletions

View file

@ -0,0 +1 @@
Allow selecting "prejoin" events by state keys in addition to event types.

View file

@ -2501,32 +2501,53 @@ Config settings related to the client/server API
--- ---
### `room_prejoin_state` ### `room_prejoin_state`
Controls for the state that is shared with users who receive an invite This setting controls the state that is shared with users upon receiving an
to a room. By default, the following state event types are shared with users who invite to a room, or in reply to a knock on a room. By default, the following
receive invites to the room: state events are shared with users:
- m.room.join_rules
- m.room.canonical_alias - `m.room.join_rules`
- m.room.avatar - `m.room.canonical_alias`
- m.room.encryption - `m.room.avatar`
- m.room.name - `m.room.encryption`
- m.room.create - `m.room.name`
- m.room.topic - `m.room.create`
- `m.room.topic`
To change the default behavior, use the following sub-options: To change the default behavior, use the following sub-options:
* `disable_default_event_types`: set to true to disable the above defaults. If this * `disable_default_event_types`: boolean. Set to `true` to disable the above
is enabled, only the event types listed in `additional_event_types` are shared. defaults. If this is enabled, only the event types listed in
Defaults to false. `additional_event_types` are shared. Defaults to `false`.
* `additional_event_types`: Additional state event types to share with users when they are invited * `additional_event_types`: A list of additional state events to include in the
to a room. By default, this list is empty (so only the default event types are shared). events to be shared. By default, this list is empty (so only the default event
types are shared).
Each entry in this list should be either a single string or a list of two
strings.
* A standalone string `t` represents all events with type `t` (i.e.
with no restrictions on state keys).
* A pair of strings `[t, s]` represents a single event with type `t` and
state key `s`. The same type can appear in two entries with different state
keys: in this situation, both state keys are included in prejoin state.
Example configuration: Example configuration:
```yaml ```yaml
room_prejoin_state: room_prejoin_state:
disable_default_event_types: true disable_default_event_types: false
additional_event_types: additional_event_types:
- org.example.custom.event.type # Share all events of type `org.example.custom.event.typeA`
- m.room.join_rules - org.example.custom.event.typeA
# Share only events of type `org.example.custom.event.typeB` whose
# state_key is "foo"
- ["org.example.custom.event.typeB", "foo"]
# Share only events of type `org.example.custom.event.typeC` whose
# state_key is "bar" or "baz"
- ["org.example.custom.event.typeC", "bar"]
- ["org.example.custom.event.typeC", "baz"]
``` ```
*Changed in Synapse 1.74:* admins can filter the events in prejoin state based
on their state key.
--- ---
### `track_puppeted_user_ips` ### `track_puppeted_user_ips`

View file

@ -89,6 +89,12 @@ disallow_untyped_defs = False
[mypy-tests.*] [mypy-tests.*]
disallow_untyped_defs = False disallow_untyped_defs = False
[mypy-tests.config.test_api]
disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client]
disallow_untyped_defs = True
[mypy-tests.handlers.test_sso] [mypy-tests.handlers.test_sso]
disallow_untyped_defs = True disallow_untyped_defs = True
@ -101,7 +107,7 @@ disallow_untyped_defs = True
[mypy-tests.push.test_bulk_push_rule_evaluator] [mypy-tests.push.test_bulk_push_rule_evaluator]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.test_server] [mypy-tests.rest.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.state.test_profile] [mypy-tests.state.test_profile]
@ -110,10 +116,10 @@ disallow_untyped_defs = True
[mypy-tests.storage.*] [mypy-tests.storage.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.rest.*] [mypy-tests.test_server]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client] [mypy-tests.types.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-tests.util.caches.*] [mypy-tests.util.caches.*]

View file

@ -33,6 +33,9 @@ def validate_config(
config: the configuration value to be validated config: the configuration value to be validated
config_path: the path within the config file. This will be used as a basis config_path: the path within the config file. This will be used as a basis
for the error message. for the error message.
Raises:
ConfigError, if validation fails.
""" """
try: try:
jsonschema.validate(config, json_schema) jsonschema.validate(config, json_schema)

View file

@ -13,12 +13,13 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Iterable from typing import Any, Iterable, Optional, Tuple
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.config._base import Config, ConfigError from synapse.config._base import Config, ConfigError
from synapse.config._util import validate_config from synapse.config._util import validate_config
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.types.state import StateFilter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -26,16 +27,20 @@ logger = logging.getLogger(__name__)
class ApiConfig(Config): class ApiConfig(Config):
section = "api" section = "api"
room_prejoin_state: StateFilter
track_puppetted_users_ips: bool
def read_config(self, config: JsonDict, **kwargs: Any) -> None: def read_config(self, config: JsonDict, **kwargs: Any) -> None:
validate_config(_MAIN_SCHEMA, config, ()) validate_config(_MAIN_SCHEMA, config, ())
self.room_prejoin_state = list(self._get_prejoin_state_types(config)) self.room_prejoin_state = StateFilter.from_types(
self._get_prejoin_state_entries(config)
)
self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False) self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False)
def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]: def _get_prejoin_state_entries(
"""Get the event types to include in the prejoin state self, config: JsonDict
) -> Iterable[Tuple[str, Optional[str]]]:
Parses the config and returns an iterable of the event types to be included. """Get the event types and state keys to include in the prejoin state."""
"""
room_prejoin_state_config = config.get("room_prejoin_state") or {} room_prejoin_state_config = config.get("room_prejoin_state") or {}
# backwards-compatibility support for room_invite_state_types # backwards-compatibility support for room_invite_state_types
@ -50,33 +55,39 @@ class ApiConfig(Config):
logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING) logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING)
yield from config["room_invite_state_types"] for event_type in config["room_invite_state_types"]:
yield event_type, None
return return
if not room_prejoin_state_config.get("disable_default_event_types"): if not room_prejoin_state_config.get("disable_default_event_types"):
yield from _DEFAULT_PREJOIN_STATE_TYPES yield from _DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS
yield from room_prejoin_state_config.get("additional_event_types", []) for entry in room_prejoin_state_config.get("additional_event_types", []):
if isinstance(entry, str):
yield entry, None
else:
yield entry
_ROOM_INVITE_STATE_TYPES_WARNING = """\ _ROOM_INVITE_STATE_TYPES_WARNING = """\
WARNING: The 'room_invite_state_types' configuration setting is now deprecated, WARNING: The 'room_invite_state_types' configuration setting is now deprecated,
and replaced with 'room_prejoin_state'. New features may not work correctly and replaced with 'room_prejoin_state'. New features may not work correctly
unless 'room_invite_state_types' is removed. See the sample configuration file for unless 'room_invite_state_types' is removed. See the config documentation at
details of 'room_prejoin_state'. https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#room_prejoin_state
for details of 'room_prejoin_state'.
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
""" """
_DEFAULT_PREJOIN_STATE_TYPES = [ _DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS = [
EventTypes.JoinRules, (EventTypes.JoinRules, ""),
EventTypes.CanonicalAlias, (EventTypes.CanonicalAlias, ""),
EventTypes.RoomAvatar, (EventTypes.RoomAvatar, ""),
EventTypes.RoomEncryption, (EventTypes.RoomEncryption, ""),
EventTypes.Name, (EventTypes.Name, ""),
# Per MSC1772. # Per MSC1772.
EventTypes.Create, (EventTypes.Create, ""),
# Per MSC3173. # Per MSC3173.
EventTypes.Topic, (EventTypes.Topic, ""),
] ]
@ -90,7 +101,17 @@ _ROOM_PREJOIN_STATE_CONFIG_SCHEMA = {
"disable_default_event_types": {"type": "boolean"}, "disable_default_event_types": {"type": "boolean"},
"additional_event_types": { "additional_event_types": {
"type": "array", "type": "array",
"items": {"type": "string"}, "items": {
"oneOf": [
{"type": "string"},
{
"type": "array",
"items": {"type": "string"},
"minItems": 2,
"maxItems": 2,
},
],
},
}, },
}, },
}, },

View file

@ -28,8 +28,14 @@ from typing import (
) )
import attr import attr
from canonicaljson import encode_canonical_json
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.constants import (
MAX_PDU_SIZE,
EventContentFields,
EventTypes,
RelationTypes,
)
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion from synapse.api.room_versions import RoomVersion
from synapse.types import JsonDict from synapse.types import JsonDict
@ -674,3 +680,27 @@ def validate_canonicaljson(value: Any) -> None:
elif not isinstance(value, (bool, str)) and value is not None: elif not isinstance(value, (bool, str)) and value is not None:
# Other potential JSON values (bool, None, str) are safe. # Other potential JSON values (bool, None, str) are safe.
raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON) raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON)
def maybe_upsert_event_field(
event: EventBase, container: JsonDict, key: str, value: object
) -> bool:
"""Upsert an event field, but only if this doesn't make the event too large.
Returns true iff the upsert took place.
"""
if key in container:
old_value: object = container[key]
container[key] = value
# NB: here and below, we assume that passing a non-None `time_now` argument to
# get_pdu_json doesn't increase the size of the encoded result.
upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE
if not upsert_okay:
container[key] = old_value
else:
container[key] = value
upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE
if not upsert_okay:
del container[key]
return upsert_okay

View file

@ -50,6 +50,7 @@ from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase, relation_from_event from synapse.events import EventBase, relation_from_event
from synapse.events.builder import EventBuilder from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.events.utils import maybe_upsert_event_field
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler from synapse.handlers.directory import DirectoryHandler
from synapse.logging import opentracing from synapse.logging import opentracing
@ -1739,12 +1740,15 @@ class EventCreationHandler:
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE: if event.content["membership"] == Membership.INVITE:
event.unsigned[ maybe_upsert_event_field(
"invite_room_state" event,
] = await self.store.get_stripped_room_state_from_event_context( event.unsigned,
context, "invite_room_state",
self.room_prejoin_state_types, await self.store.get_stripped_room_state_from_event_context(
membership_user_id=event.sender, context,
self.room_prejoin_state_types,
membership_user_id=event.sender,
),
) )
invitee = UserID.from_string(event.state_key) invitee = UserID.from_string(event.state_key)
@ -1762,11 +1766,14 @@ class EventCreationHandler:
event.signatures.update(returned_invite.signatures) event.signatures.update(returned_invite.signatures)
if event.content["membership"] == Membership.KNOCK: if event.content["membership"] == Membership.KNOCK:
event.unsigned[ maybe_upsert_event_field(
"knock_room_state" event,
] = await self.store.get_stripped_room_state_from_event_context( event.unsigned,
context, "knock_room_state",
self.room_prejoin_state_types, await self.store.get_stripped_room_state_from_event_context(
context,
self.room_prejoin_state_types,
),
) )
if event.type == EventTypes.Redaction: if event.type == EventTypes.Redaction:

View file

@ -16,11 +16,11 @@ import logging
import threading import threading
import weakref import weakref
from enum import Enum, auto from enum import Enum, auto
from itertools import chain
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Collection, Collection,
Container,
Dict, Dict,
Iterable, Iterable,
List, List,
@ -76,6 +76,7 @@ from synapse.storage.util.id_generators import (
) )
from synapse.storage.util.sequence import build_sequence_generator from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id from synapse.types import JsonDict, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
@ -879,7 +880,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_stripped_room_state_from_event_context( async def get_stripped_room_state_from_event_context(
self, self,
context: EventContext, context: EventContext,
state_types_to_include: Container[str], state_keys_to_include: StateFilter,
membership_user_id: Optional[str] = None, membership_user_id: Optional[str] = None,
) -> List[JsonDict]: ) -> List[JsonDict]:
""" """
@ -892,7 +893,7 @@ class EventsWorkerStore(SQLBaseStore):
Args: Args:
context: The event context to retrieve state of the room from. context: The event context to retrieve state of the room from.
state_types_to_include: The type of state events to include. state_keys_to_include: The state events to include, for each event type.
membership_user_id: An optional user ID to include the stripped membership state membership_user_id: An optional user ID to include the stripped membership state
events of. This is useful when generating the stripped state of a room for events of. This is useful when generating the stripped state of a room for
invites. We want to send membership events of the inviter, so that the invites. We want to send membership events of the inviter, so that the
@ -901,21 +902,25 @@ class EventsWorkerStore(SQLBaseStore):
Returns: Returns:
A list of dictionaries, each representing a stripped state event from the room. A list of dictionaries, each representing a stripped state event from the room.
""" """
current_state_ids = await context.get_current_state_ids() if membership_user_id:
types = chain(
state_keys_to_include.to_types(),
[(EventTypes.Member, membership_user_id)],
)
filter = StateFilter.from_types(types)
else:
filter = state_keys_to_include
selected_state_ids = await context.get_current_state_ids(filter)
# We know this event is not an outlier, so this must be # We know this event is not an outlier, so this must be
# non-None. # non-None.
assert current_state_ids is not None assert selected_state_ids is not None
# The state to include # Confusingly, get_current_state_events may return events that are discarded by
state_to_include_ids = [ # the filter, if they're in context._state_delta_due_to_event. Strip these away.
e_id selected_state_ids = filter.filter_state(selected_state_ids)
for k, e_id in current_state_ids.items()
if k[0] in state_types_to_include
or (membership_user_id and k == (EventTypes.Member, membership_user_id))
]
state_to_include = await self.get_events(state_to_include_ids) state_to_include = await self.get_events(selected_state_ids.values())
return [ return [
{ {

View file

@ -118,6 +118,15 @@ class StateFilter:
) )
) )
def to_types(self) -> Iterable[Tuple[str, Optional[str]]]:
"""The inverse to `from_types`."""
for (event_type, state_keys) in self.types.items():
if state_keys is None:
yield event_type, None
else:
for state_key in state_keys:
yield event_type, state_key
@staticmethod @staticmethod
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
"""Creates a filter that returns all non-member events, plus the member """Creates a filter that returns all non-member events, plus the member
@ -343,6 +352,15 @@ class StateFilter:
for s in state_keys for s in state_keys
] ]
def wildcard_types(self) -> List[str]:
"""Returns a list of event types which require us to fetch all state keys.
This will be empty unless `has_wildcards` returns True.
Returns:
A list of event types.
"""
return [t for t, state_keys in self.types.items() if state_keys is None]
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
"""Return the filter split into two: one which assumes it's exclusively """Return the filter split into two: one which assumes it's exclusively
matching against member state, and one which assumes it's matching matching against member state, and one which assumes it's matching

145
tests/config/test_api.py Normal file
View file

@ -0,0 +1,145 @@
from unittest import TestCase as StdlibTestCase
import yaml
from synapse.config import ConfigError
from synapse.config.api import ApiConfig
from synapse.types.state import StateFilter
DEFAULT_PREJOIN_STATE_PAIRS = {
("m.room.join_rules", ""),
("m.room.canonical_alias", ""),
("m.room.avatar", ""),
("m.room.encryption", ""),
("m.room.name", ""),
("m.room.create", ""),
("m.room.topic", ""),
}
class TestRoomPrejoinState(StdlibTestCase):
def read_config(self, source: str) -> ApiConfig:
config = ApiConfig()
config.read_config(yaml.safe_load(source))
return config
def test_no_prejoin_state(self) -> None:
config = self.read_config("foo: bar")
self.assertFalse(config.room_prejoin_state.has_wildcards())
self.assertEqual(
set(config.room_prejoin_state.concrete_types()), DEFAULT_PREJOIN_STATE_PAIRS
)
def test_disable_default_event_types(self) -> None:
config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
"""
)
self.assertEqual(config.room_prejoin_state, StateFilter.none())
def test_event_without_state_key(self) -> None:
config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
additional_event_types:
- foo
"""
)
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
self.assertEqual(config.room_prejoin_state.concrete_types(), [])
def test_event_with_specific_state_key(self) -> None:
config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
additional_event_types:
- [foo, bar]
"""
)
self.assertFalse(config.room_prejoin_state.has_wildcards())
self.assertEqual(
set(config.room_prejoin_state.concrete_types()),
{("foo", "bar")},
)
def test_repeated_event_with_specific_state_key(self) -> None:
config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
additional_event_types:
- [foo, bar]
- [foo, baz]
"""
)
self.assertFalse(config.room_prejoin_state.has_wildcards())
self.assertEqual(
set(config.room_prejoin_state.concrete_types()),
{("foo", "bar"), ("foo", "baz")},
)
def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
additional_event_types:
- [foo, bar]
- foo
"""
)
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
self.assertEqual(config.room_prejoin_state.concrete_types(), [])
config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
additional_event_types:
- foo
- [foo, bar]
"""
)
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
self.assertEqual(config.room_prejoin_state.concrete_types(), [])
def test_bad_event_type_entry_raises(self) -> None:
with self.assertRaises(ConfigError):
self.read_config(
"""
room_prejoin_state:
additional_event_types:
- []
"""
)
with self.assertRaises(ConfigError):
self.read_config(
"""
room_prejoin_state:
additional_event_types:
- [a]
"""
)
with self.assertRaises(ConfigError):
self.read_config(
"""
room_prejoin_state:
additional_event_types:
- [a, b, c]
"""
)
with self.assertRaises(ConfigError):
self.read_config(
"""
room_prejoin_state:
additional_event_types:
- [true, 1.23]
"""
)

View file

@ -12,19 +12,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest as stdlib_unittest
from synapse.api.constants import EventContentFields from synapse.api.constants import EventContentFields
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict from synapse.events import make_event_from_dict
from synapse.events.utils import ( from synapse.events.utils import (
SerializeEventConfig, SerializeEventConfig,
copy_and_fixup_power_levels_contents, copy_and_fixup_power_levels_contents,
maybe_upsert_event_field,
prune_event, prune_event,
serialize_event, serialize_event,
) )
from synapse.util.frozenutils import freeze from synapse.util.frozenutils import freeze
from tests import unittest
def MockEvent(**kwargs): def MockEvent(**kwargs):
if "event_id" not in kwargs: if "event_id" not in kwargs:
@ -34,7 +35,31 @@ def MockEvent(**kwargs):
return make_event_from_dict(kwargs) return make_event_from_dict(kwargs)
class PruneEventTestCase(unittest.TestCase): class TestMaybeUpsertEventField(stdlib_unittest.TestCase):
def test_update_okay(self) -> None:
event = make_event_from_dict({"event_id": "$1234"})
success = maybe_upsert_event_field(event, event.unsigned, "key", "value")
self.assertTrue(success)
self.assertEqual(event.unsigned["key"], "value")
def test_update_not_okay(self) -> None:
event = make_event_from_dict({"event_id": "$1234"})
LARGE_STRING = "a" * 100_000
success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
self.assertFalse(success)
self.assertNotIn("key", event.unsigned)
def test_update_not_okay_leaves_original_value(self) -> None:
event = make_event_from_dict(
{"event_id": "$1234", "unsigned": {"key": "value"}}
)
LARGE_STRING = "a" * 100_000
success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
self.assertFalse(success)
self.assertEqual(event.unsigned["key"], "value")
class PruneEventTestCase(stdlib_unittest.TestCase):
def run_test(self, evdict, matchdict, **kwargs): def run_test(self, evdict, matchdict, **kwargs):
""" """
Asserts that a new event constructed with `evdict` will look like Asserts that a new event constructed with `evdict` will look like
@ -391,7 +416,7 @@ class PruneEventTestCase(unittest.TestCase):
) )
class SerializeEventTestCase(unittest.TestCase): class SerializeEventTestCase(stdlib_unittest.TestCase):
def serialize(self, ev, fields): def serialize(self, ev, fields):
return serialize_event( return serialize_event(
ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields) ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
@ -513,7 +538,7 @@ class SerializeEventTestCase(unittest.TestCase):
) )
class CopyPowerLevelsContentTestCase(unittest.TestCase): class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.test_content = { self.test_content = {
"ban": 50, "ban": 50,

View file

@ -26,7 +26,7 @@ from synapse.types import JsonDict, RoomID, StateMap, UserID
from synapse.types.state import StateFilter from synapse.types.state import StateFilter
from synapse.util import Clock from synapse.util import Clock
from tests.unittest import HomeserverTestCase, TestCase from tests.unittest import HomeserverTestCase
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -494,624 +494,3 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
class StateFilterDifferenceTestCase(TestCase):
def assert_difference(
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
) -> None:
self.assertEqual(
minuend.approx_difference(subtrahend),
expected,
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
)
def test_state_filter_difference_no_include_other_minus_no_include_other(
self,
) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b do not have the
include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
StateFilter.freeze({EventTypes.Create: None}, include_others=False),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:spqr"}},
include_others=False,
),
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.CanonicalAlias: {""}},
include_others=False,
),
)
# (specific state keys) - (specific state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
)
def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only a has the include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Create: None,
EventTypes.Member: set(),
EventTypes.CanonicalAlias: set(),
},
include_others=True,
),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
# This also shows that the resultant state filter is normalised.
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.Create: {""},
},
include_others=False,
),
StateFilter(types=frozendict(), include_others=True),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter(
types=frozendict(),
include_others=True,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.CanonicalAlias: {""},
EventTypes.Member: set(),
},
include_others=True,
),
)
# (specific state keys) - (specific state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
)
def test_state_filter_difference_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b have the include_others
flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=True,
),
StateFilter(types=frozendict(), include_others=False),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter(
types=frozendict(),
include_others=False,
),
)
# (specific state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
EventTypes.Create: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.Create: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.Create: {""},
},
include_others=False,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
},
include_others=False,
),
)
def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only b has the include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=True,
),
StateFilter(types=frozendict(), include_others=False),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:spqr"}},
include_others=True,
),
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter(
types=frozendict(),
include_others=False,
),
)
# (specific state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
},
include_others=False,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
},
include_others=False,
),
)
def test_state_filter_difference_simple_cases(self) -> None:
"""
Tests some very simple cases of the StateFilter approx_difference,
that are not explicitly tested by the more in-depth tests.
"""
self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
self.assert_difference(
StateFilter.all(),
StateFilter.none(),
StateFilter.all(),
)
class StateFilterTestCase(TestCase):
def test_return_expanded(self) -> None:
"""
Tests the behaviour of the return_expanded() function that expands
StateFilters to include more state types (for the sake of cache hit rate).
"""
self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
# Concrete-only state filters stay the same
# (Case: mixed filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": {""},
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": {""},
},
include_others=False,
),
)
# Concrete-only state filters stay the same
# (Case: non-member-only filter)
self.assertEqual(
StateFilter.freeze(
{"some.other.state.type": {""}}, include_others=False
).return_expanded(),
StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
)
# Concrete-only state filters stay the same
# (Case: member-only filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
},
include_others=False,
),
)
# Wildcard member-only state filters stay the same
self.assertEqual(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# If there is a wildcard in the non-member portion of the filter,
# it's expanded to include ALL non-member events.
# (Case: mixed filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": None,
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:test", "@alicia:test"}},
include_others=True,
),
)
# If there is a wildcard in the non-member portion of the filter,
# it's expanded to include ALL non-member events.
# (Case: non-member-only filter)
self.assertEqual(
StateFilter.freeze(
{
"some.other.state.type": None,
},
include_others=False,
).return_expanded(),
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
)
self.assertEqual(
StateFilter.freeze(
{
"some.other.state.type": None,
"yet.another.state.type": {"wombat"},
},
include_others=False,
).return_expanded(),
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
)

0
tests/types/__init__.py Normal file
View file

627
tests/types/test_state.py Normal file
View file

@ -0,0 +1,627 @@
from frozendict import frozendict
from synapse.api.constants import EventTypes
from synapse.types.state import StateFilter
from tests.unittest import TestCase
class StateFilterDifferenceTestCase(TestCase):
def assert_difference(
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
) -> None:
self.assertEqual(
minuend.approx_difference(subtrahend),
expected,
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
)
def test_state_filter_difference_no_include_other_minus_no_include_other(
self,
) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b do not have the
include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
StateFilter.freeze({EventTypes.Create: None}, include_others=False),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:spqr"}},
include_others=False,
),
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.CanonicalAlias: {""}},
include_others=False,
),
)
# (specific state keys) - (specific state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
)
def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only a has the include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Create: None,
EventTypes.Member: set(),
EventTypes.CanonicalAlias: set(),
},
include_others=True,
),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
# This also shows that the resultant state filter is normalised.
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.Create: {""},
},
include_others=False,
),
StateFilter(types=frozendict(), include_others=True),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter(
types=frozendict(),
include_others=True,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.CanonicalAlias: {""},
EventTypes.Member: set(),
},
include_others=True,
),
)
# (specific state keys) - (specific state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
)
def test_state_filter_difference_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b have the include_others
flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=True,
),
StateFilter(types=frozendict(), include_others=False),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter(
types=frozendict(),
include_others=False,
),
)
# (specific state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
EventTypes.Create: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.Create: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.Create: {""},
},
include_others=False,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
},
include_others=False,
),
)
def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only b has the include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=True,
),
StateFilter(types=frozendict(), include_others=False),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:spqr"}},
include_others=True,
),
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter(
types=frozendict(),
include_others=False,
),
)
# (specific state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
},
include_others=False,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
},
include_others=False,
),
)
def test_state_filter_difference_simple_cases(self) -> None:
"""
Tests some very simple cases of the StateFilter approx_difference,
that are not explicitly tested by the more in-depth tests.
"""
self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
self.assert_difference(
StateFilter.all(),
StateFilter.none(),
StateFilter.all(),
)
class StateFilterTestCase(TestCase):
def test_return_expanded(self) -> None:
"""
Tests the behaviour of the return_expanded() function that expands
StateFilters to include more state types (for the sake of cache hit rate).
"""
self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
# Concrete-only state filters stay the same
# (Case: mixed filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": {""},
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": {""},
},
include_others=False,
),
)
# Concrete-only state filters stay the same
# (Case: non-member-only filter)
self.assertEqual(
StateFilter.freeze(
{"some.other.state.type": {""}}, include_others=False
).return_expanded(),
StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
)
# Concrete-only state filters stay the same
# (Case: member-only filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
},
include_others=False,
),
)
# Wildcard member-only state filters stay the same
self.assertEqual(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# If there is a wildcard in the non-member portion of the filter,
# it's expanded to include ALL non-member events.
# (Case: mixed filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": None,
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:test", "@alicia:test"}},
include_others=True,
),
)
# If there is a wildcard in the non-member portion of the filter,
# it's expanded to include ALL non-member events.
# (Case: non-member-only filter)
self.assertEqual(
StateFilter.freeze(
{
"some.other.state.type": None,
},
include_others=False,
).return_expanded(),
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
)
self.assertEqual(
StateFilter.freeze(
{
"some.other.state.type": None,
"yet.another.state.type": {"wombat"},
},
include_others=False,
).return_expanded(),
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
)