mirror of
https://github.com/element-hq/synapse.git
synced 2024-11-22 01:25:44 +03:00
Allow selecting "prejoin" events by state keys (#14642)
* Declare new config * Parse new config * Read new config * Don't use trial/our TestCase where it's not needed Before: ``` $ time trial tests/events/test_utils.py > /dev/null real 0m2.277s user 0m2.186s sys 0m0.083s ``` After: ``` $ time trial tests/events/test_utils.py > /dev/null real 0m0.566s user 0m0.508s sys 0m0.056s ``` * Helper to upsert to event fields without exceeding size limits. * Use helper when adding invite/knock state Now that we allow admins to include events in prejoin room state with arbitrary state keys, be a good Matrix citizen and ensure they don't accidentally create an oversized event. * Changelog * Move StateFilter tests should have done this in #14668 * Add extra methods to StateFilter * Use StateFilter * Ensure test file enforces typed defs; alphabetise * Workaround surprising get_current_state_ids * Whoops, fix mypy
This commit is contained in:
parent
3d87847ecc
commit
e2a1adbf5d
14 changed files with 982 additions and 694 deletions
1
changelog.d/14642.feature
Normal file
1
changelog.d/14642.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Allow selecting "prejoin" events by state keys in addition to event types.
|
|
@ -2501,32 +2501,53 @@ Config settings related to the client/server API
|
||||||
---
|
---
|
||||||
### `room_prejoin_state`
|
### `room_prejoin_state`
|
||||||
|
|
||||||
Controls for the state that is shared with users who receive an invite
|
This setting controls the state that is shared with users upon receiving an
|
||||||
to a room. By default, the following state event types are shared with users who
|
invite to a room, or in reply to a knock on a room. By default, the following
|
||||||
receive invites to the room:
|
state events are shared with users:
|
||||||
- m.room.join_rules
|
|
||||||
- m.room.canonical_alias
|
- `m.room.join_rules`
|
||||||
- m.room.avatar
|
- `m.room.canonical_alias`
|
||||||
- m.room.encryption
|
- `m.room.avatar`
|
||||||
- m.room.name
|
- `m.room.encryption`
|
||||||
- m.room.create
|
- `m.room.name`
|
||||||
- m.room.topic
|
- `m.room.create`
|
||||||
|
- `m.room.topic`
|
||||||
|
|
||||||
To change the default behavior, use the following sub-options:
|
To change the default behavior, use the following sub-options:
|
||||||
* `disable_default_event_types`: set to true to disable the above defaults. If this
|
* `disable_default_event_types`: boolean. Set to `true` to disable the above
|
||||||
is enabled, only the event types listed in `additional_event_types` are shared.
|
defaults. If this is enabled, only the event types listed in
|
||||||
Defaults to false.
|
`additional_event_types` are shared. Defaults to `false`.
|
||||||
* `additional_event_types`: Additional state event types to share with users when they are invited
|
* `additional_event_types`: A list of additional state events to include in the
|
||||||
to a room. By default, this list is empty (so only the default event types are shared).
|
events to be shared. By default, this list is empty (so only the default event
|
||||||
|
types are shared).
|
||||||
|
|
||||||
|
Each entry in this list should be either a single string or a list of two
|
||||||
|
strings.
|
||||||
|
* A standalone string `t` represents all events with type `t` (i.e.
|
||||||
|
with no restrictions on state keys).
|
||||||
|
* A pair of strings `[t, s]` represents a single event with type `t` and
|
||||||
|
state key `s`. The same type can appear in two entries with different state
|
||||||
|
keys: in this situation, both state keys are included in prejoin state.
|
||||||
|
|
||||||
Example configuration:
|
Example configuration:
|
||||||
```yaml
|
```yaml
|
||||||
room_prejoin_state:
|
room_prejoin_state:
|
||||||
disable_default_event_types: true
|
disable_default_event_types: false
|
||||||
additional_event_types:
|
additional_event_types:
|
||||||
- org.example.custom.event.type
|
# Share all events of type `org.example.custom.event.typeA`
|
||||||
- m.room.join_rules
|
- org.example.custom.event.typeA
|
||||||
|
# Share only events of type `org.example.custom.event.typeB` whose
|
||||||
|
# state_key is "foo"
|
||||||
|
- ["org.example.custom.event.typeB", "foo"]
|
||||||
|
# Share only events of type `org.example.custom.event.typeC` whose
|
||||||
|
# state_key is "bar" or "baz"
|
||||||
|
- ["org.example.custom.event.typeC", "bar"]
|
||||||
|
- ["org.example.custom.event.typeC", "baz"]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
*Changed in Synapse 1.74:* admins can filter the events in prejoin state based
|
||||||
|
on their state key.
|
||||||
|
|
||||||
---
|
---
|
||||||
### `track_puppeted_user_ips`
|
### `track_puppeted_user_ips`
|
||||||
|
|
||||||
|
|
12
mypy.ini
12
mypy.ini
|
@ -89,6 +89,12 @@ disallow_untyped_defs = False
|
||||||
[mypy-tests.*]
|
[mypy-tests.*]
|
||||||
disallow_untyped_defs = False
|
disallow_untyped_defs = False
|
||||||
|
|
||||||
|
[mypy-tests.config.test_api]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-tests.federation.transport.test_client]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-tests.handlers.test_sso]
|
[mypy-tests.handlers.test_sso]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
@ -101,7 +107,7 @@ disallow_untyped_defs = True
|
||||||
[mypy-tests.push.test_bulk_push_rule_evaluator]
|
[mypy-tests.push.test_bulk_push_rule_evaluator]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-tests.test_server]
|
[mypy-tests.rest.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-tests.state.test_profile]
|
[mypy-tests.state.test_profile]
|
||||||
|
@ -110,10 +116,10 @@ disallow_untyped_defs = True
|
||||||
[mypy-tests.storage.*]
|
[mypy-tests.storage.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-tests.rest.*]
|
[mypy-tests.test_server]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-tests.federation.transport.test_client]
|
[mypy-tests.types.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-tests.util.caches.*]
|
[mypy-tests.util.caches.*]
|
||||||
|
|
|
@ -33,6 +33,9 @@ def validate_config(
|
||||||
config: the configuration value to be validated
|
config: the configuration value to be validated
|
||||||
config_path: the path within the config file. This will be used as a basis
|
config_path: the path within the config file. This will be used as a basis
|
||||||
for the error message.
|
for the error message.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ConfigError, if validation fails.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
jsonschema.validate(config, json_schema)
|
jsonschema.validate(config, json_schema)
|
||||||
|
|
|
@ -13,12 +13,13 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Iterable
|
from typing import Any, Iterable, Optional, Tuple
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.config._base import Config, ConfigError
|
from synapse.config._base import Config, ConfigError
|
||||||
from synapse.config._util import validate_config
|
from synapse.config._util import validate_config
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
from synapse.types.state import StateFilter
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -26,16 +27,20 @@ logger = logging.getLogger(__name__)
|
||||||
class ApiConfig(Config):
|
class ApiConfig(Config):
|
||||||
section = "api"
|
section = "api"
|
||||||
|
|
||||||
|
room_prejoin_state: StateFilter
|
||||||
|
track_puppetted_users_ips: bool
|
||||||
|
|
||||||
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
|
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
|
||||||
validate_config(_MAIN_SCHEMA, config, ())
|
validate_config(_MAIN_SCHEMA, config, ())
|
||||||
self.room_prejoin_state = list(self._get_prejoin_state_types(config))
|
self.room_prejoin_state = StateFilter.from_types(
|
||||||
|
self._get_prejoin_state_entries(config)
|
||||||
|
)
|
||||||
self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False)
|
self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False)
|
||||||
|
|
||||||
def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]:
|
def _get_prejoin_state_entries(
|
||||||
"""Get the event types to include in the prejoin state
|
self, config: JsonDict
|
||||||
|
) -> Iterable[Tuple[str, Optional[str]]]:
|
||||||
Parses the config and returns an iterable of the event types to be included.
|
"""Get the event types and state keys to include in the prejoin state."""
|
||||||
"""
|
|
||||||
room_prejoin_state_config = config.get("room_prejoin_state") or {}
|
room_prejoin_state_config = config.get("room_prejoin_state") or {}
|
||||||
|
|
||||||
# backwards-compatibility support for room_invite_state_types
|
# backwards-compatibility support for room_invite_state_types
|
||||||
|
@ -50,33 +55,39 @@ class ApiConfig(Config):
|
||||||
|
|
||||||
logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING)
|
logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING)
|
||||||
|
|
||||||
yield from config["room_invite_state_types"]
|
for event_type in config["room_invite_state_types"]:
|
||||||
|
yield event_type, None
|
||||||
return
|
return
|
||||||
|
|
||||||
if not room_prejoin_state_config.get("disable_default_event_types"):
|
if not room_prejoin_state_config.get("disable_default_event_types"):
|
||||||
yield from _DEFAULT_PREJOIN_STATE_TYPES
|
yield from _DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS
|
||||||
|
|
||||||
yield from room_prejoin_state_config.get("additional_event_types", [])
|
for entry in room_prejoin_state_config.get("additional_event_types", []):
|
||||||
|
if isinstance(entry, str):
|
||||||
|
yield entry, None
|
||||||
|
else:
|
||||||
|
yield entry
|
||||||
|
|
||||||
|
|
||||||
_ROOM_INVITE_STATE_TYPES_WARNING = """\
|
_ROOM_INVITE_STATE_TYPES_WARNING = """\
|
||||||
WARNING: The 'room_invite_state_types' configuration setting is now deprecated,
|
WARNING: The 'room_invite_state_types' configuration setting is now deprecated,
|
||||||
and replaced with 'room_prejoin_state'. New features may not work correctly
|
and replaced with 'room_prejoin_state'. New features may not work correctly
|
||||||
unless 'room_invite_state_types' is removed. See the sample configuration file for
|
unless 'room_invite_state_types' is removed. See the config documentation at
|
||||||
details of 'room_prejoin_state'.
|
https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#room_prejoin_state
|
||||||
|
for details of 'room_prejoin_state'.
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_DEFAULT_PREJOIN_STATE_TYPES = [
|
_DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS = [
|
||||||
EventTypes.JoinRules,
|
(EventTypes.JoinRules, ""),
|
||||||
EventTypes.CanonicalAlias,
|
(EventTypes.CanonicalAlias, ""),
|
||||||
EventTypes.RoomAvatar,
|
(EventTypes.RoomAvatar, ""),
|
||||||
EventTypes.RoomEncryption,
|
(EventTypes.RoomEncryption, ""),
|
||||||
EventTypes.Name,
|
(EventTypes.Name, ""),
|
||||||
# Per MSC1772.
|
# Per MSC1772.
|
||||||
EventTypes.Create,
|
(EventTypes.Create, ""),
|
||||||
# Per MSC3173.
|
# Per MSC3173.
|
||||||
EventTypes.Topic,
|
(EventTypes.Topic, ""),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -89,8 +100,18 @@ _ROOM_PREJOIN_STATE_CONFIG_SCHEMA = {
|
||||||
"properties": {
|
"properties": {
|
||||||
"disable_default_event_types": {"type": "boolean"},
|
"disable_default_event_types": {"type": "boolean"},
|
||||||
"additional_event_types": {
|
"additional_event_types": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"oneOf": [
|
||||||
|
{"type": "string"},
|
||||||
|
{
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
|
"minItems": 2,
|
||||||
|
"maxItems": 2,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
@ -28,8 +28,14 @@ from typing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
from canonicaljson import encode_canonical_json
|
||||||
|
|
||||||
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
|
from synapse.api.constants import (
|
||||||
|
MAX_PDU_SIZE,
|
||||||
|
EventContentFields,
|
||||||
|
EventTypes,
|
||||||
|
RelationTypes,
|
||||||
|
)
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.api.room_versions import RoomVersion
|
from synapse.api.room_versions import RoomVersion
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
@ -674,3 +680,27 @@ def validate_canonicaljson(value: Any) -> None:
|
||||||
elif not isinstance(value, (bool, str)) and value is not None:
|
elif not isinstance(value, (bool, str)) and value is not None:
|
||||||
# Other potential JSON values (bool, None, str) are safe.
|
# Other potential JSON values (bool, None, str) are safe.
|
||||||
raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON)
|
raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_upsert_event_field(
|
||||||
|
event: EventBase, container: JsonDict, key: str, value: object
|
||||||
|
) -> bool:
|
||||||
|
"""Upsert an event field, but only if this doesn't make the event too large.
|
||||||
|
|
||||||
|
Returns true iff the upsert took place.
|
||||||
|
"""
|
||||||
|
if key in container:
|
||||||
|
old_value: object = container[key]
|
||||||
|
container[key] = value
|
||||||
|
# NB: here and below, we assume that passing a non-None `time_now` argument to
|
||||||
|
# get_pdu_json doesn't increase the size of the encoded result.
|
||||||
|
upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE
|
||||||
|
if not upsert_okay:
|
||||||
|
container[key] = old_value
|
||||||
|
else:
|
||||||
|
container[key] = value
|
||||||
|
upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE
|
||||||
|
if not upsert_okay:
|
||||||
|
del container[key]
|
||||||
|
|
||||||
|
return upsert_okay
|
||||||
|
|
|
@ -50,6 +50,7 @@ from synapse.event_auth import validate_event_for_room_version
|
||||||
from synapse.events import EventBase, relation_from_event
|
from synapse.events import EventBase, relation_from_event
|
||||||
from synapse.events.builder import EventBuilder
|
from synapse.events.builder import EventBuilder
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
|
from synapse.events.utils import maybe_upsert_event_field
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
from synapse.handlers.directory import DirectoryHandler
|
from synapse.handlers.directory import DirectoryHandler
|
||||||
from synapse.logging import opentracing
|
from synapse.logging import opentracing
|
||||||
|
@ -1739,12 +1740,15 @@ class EventCreationHandler:
|
||||||
|
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
if event.content["membership"] == Membership.INVITE:
|
if event.content["membership"] == Membership.INVITE:
|
||||||
event.unsigned[
|
maybe_upsert_event_field(
|
||||||
"invite_room_state"
|
event,
|
||||||
] = await self.store.get_stripped_room_state_from_event_context(
|
event.unsigned,
|
||||||
|
"invite_room_state",
|
||||||
|
await self.store.get_stripped_room_state_from_event_context(
|
||||||
context,
|
context,
|
||||||
self.room_prejoin_state_types,
|
self.room_prejoin_state_types,
|
||||||
membership_user_id=event.sender,
|
membership_user_id=event.sender,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
invitee = UserID.from_string(event.state_key)
|
invitee = UserID.from_string(event.state_key)
|
||||||
|
@ -1762,11 +1766,14 @@ class EventCreationHandler:
|
||||||
event.signatures.update(returned_invite.signatures)
|
event.signatures.update(returned_invite.signatures)
|
||||||
|
|
||||||
if event.content["membership"] == Membership.KNOCK:
|
if event.content["membership"] == Membership.KNOCK:
|
||||||
event.unsigned[
|
maybe_upsert_event_field(
|
||||||
"knock_room_state"
|
event,
|
||||||
] = await self.store.get_stripped_room_state_from_event_context(
|
event.unsigned,
|
||||||
|
"knock_room_state",
|
||||||
|
await self.store.get_stripped_room_state_from_event_context(
|
||||||
context,
|
context,
|
||||||
self.room_prejoin_state_types,
|
self.room_prejoin_state_types,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.type == EventTypes.Redaction:
|
if event.type == EventTypes.Redaction:
|
||||||
|
|
|
@ -16,11 +16,11 @@ import logging
|
||||||
import threading
|
import threading
|
||||||
import weakref
|
import weakref
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
from itertools import chain
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Collection,
|
Collection,
|
||||||
Container,
|
|
||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
|
@ -76,6 +76,7 @@ from synapse.storage.util.id_generators import (
|
||||||
)
|
)
|
||||||
from synapse.storage.util.sequence import build_sequence_generator
|
from synapse.storage.util.sequence import build_sequence_generator
|
||||||
from synapse.types import JsonDict, get_domain_from_id
|
from synapse.types import JsonDict, get_domain_from_id
|
||||||
|
from synapse.types.state import StateFilter
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
|
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
@ -879,7 +880,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
async def get_stripped_room_state_from_event_context(
|
async def get_stripped_room_state_from_event_context(
|
||||||
self,
|
self,
|
||||||
context: EventContext,
|
context: EventContext,
|
||||||
state_types_to_include: Container[str],
|
state_keys_to_include: StateFilter,
|
||||||
membership_user_id: Optional[str] = None,
|
membership_user_id: Optional[str] = None,
|
||||||
) -> List[JsonDict]:
|
) -> List[JsonDict]:
|
||||||
"""
|
"""
|
||||||
|
@ -892,7 +893,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
context: The event context to retrieve state of the room from.
|
context: The event context to retrieve state of the room from.
|
||||||
state_types_to_include: The type of state events to include.
|
state_keys_to_include: The state events to include, for each event type.
|
||||||
membership_user_id: An optional user ID to include the stripped membership state
|
membership_user_id: An optional user ID to include the stripped membership state
|
||||||
events of. This is useful when generating the stripped state of a room for
|
events of. This is useful when generating the stripped state of a room for
|
||||||
invites. We want to send membership events of the inviter, so that the
|
invites. We want to send membership events of the inviter, so that the
|
||||||
|
@ -901,21 +902,25 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
Returns:
|
Returns:
|
||||||
A list of dictionaries, each representing a stripped state event from the room.
|
A list of dictionaries, each representing a stripped state event from the room.
|
||||||
"""
|
"""
|
||||||
current_state_ids = await context.get_current_state_ids()
|
if membership_user_id:
|
||||||
|
types = chain(
|
||||||
|
state_keys_to_include.to_types(),
|
||||||
|
[(EventTypes.Member, membership_user_id)],
|
||||||
|
)
|
||||||
|
filter = StateFilter.from_types(types)
|
||||||
|
else:
|
||||||
|
filter = state_keys_to_include
|
||||||
|
selected_state_ids = await context.get_current_state_ids(filter)
|
||||||
|
|
||||||
# We know this event is not an outlier, so this must be
|
# We know this event is not an outlier, so this must be
|
||||||
# non-None.
|
# non-None.
|
||||||
assert current_state_ids is not None
|
assert selected_state_ids is not None
|
||||||
|
|
||||||
# The state to include
|
# Confusingly, get_current_state_events may return events that are discarded by
|
||||||
state_to_include_ids = [
|
# the filter, if they're in context._state_delta_due_to_event. Strip these away.
|
||||||
e_id
|
selected_state_ids = filter.filter_state(selected_state_ids)
|
||||||
for k, e_id in current_state_ids.items()
|
|
||||||
if k[0] in state_types_to_include
|
|
||||||
or (membership_user_id and k == (EventTypes.Member, membership_user_id))
|
|
||||||
]
|
|
||||||
|
|
||||||
state_to_include = await self.get_events(state_to_include_ids)
|
state_to_include = await self.get_events(selected_state_ids.values())
|
||||||
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
|
|
|
@ -118,6 +118,15 @@ class StateFilter:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_types(self) -> Iterable[Tuple[str, Optional[str]]]:
|
||||||
|
"""The inverse to `from_types`."""
|
||||||
|
for (event_type, state_keys) in self.types.items():
|
||||||
|
if state_keys is None:
|
||||||
|
yield event_type, None
|
||||||
|
else:
|
||||||
|
for state_key in state_keys:
|
||||||
|
yield event_type, state_key
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
|
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
|
||||||
"""Creates a filter that returns all non-member events, plus the member
|
"""Creates a filter that returns all non-member events, plus the member
|
||||||
|
@ -343,6 +352,15 @@ class StateFilter:
|
||||||
for s in state_keys
|
for s in state_keys
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def wildcard_types(self) -> List[str]:
|
||||||
|
"""Returns a list of event types which require us to fetch all state keys.
|
||||||
|
This will be empty unless `has_wildcards` returns True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of event types.
|
||||||
|
"""
|
||||||
|
return [t for t, state_keys in self.types.items() if state_keys is None]
|
||||||
|
|
||||||
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
|
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
|
||||||
"""Return the filter split into two: one which assumes it's exclusively
|
"""Return the filter split into two: one which assumes it's exclusively
|
||||||
matching against member state, and one which assumes it's matching
|
matching against member state, and one which assumes it's matching
|
||||||
|
|
145
tests/config/test_api.py
Normal file
145
tests/config/test_api.py
Normal file
|
@ -0,0 +1,145 @@
|
||||||
|
from unittest import TestCase as StdlibTestCase
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from synapse.config import ConfigError
|
||||||
|
from synapse.config.api import ApiConfig
|
||||||
|
from synapse.types.state import StateFilter
|
||||||
|
|
||||||
|
DEFAULT_PREJOIN_STATE_PAIRS = {
|
||||||
|
("m.room.join_rules", ""),
|
||||||
|
("m.room.canonical_alias", ""),
|
||||||
|
("m.room.avatar", ""),
|
||||||
|
("m.room.encryption", ""),
|
||||||
|
("m.room.name", ""),
|
||||||
|
("m.room.create", ""),
|
||||||
|
("m.room.topic", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoomPrejoinState(StdlibTestCase):
|
||||||
|
def read_config(self, source: str) -> ApiConfig:
|
||||||
|
config = ApiConfig()
|
||||||
|
config.read_config(yaml.safe_load(source))
|
||||||
|
return config
|
||||||
|
|
||||||
|
def test_no_prejoin_state(self) -> None:
|
||||||
|
config = self.read_config("foo: bar")
|
||||||
|
self.assertFalse(config.room_prejoin_state.has_wildcards())
|
||||||
|
self.assertEqual(
|
||||||
|
set(config.room_prejoin_state.concrete_types()), DEFAULT_PREJOIN_STATE_PAIRS
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_disable_default_event_types(self) -> None:
|
||||||
|
config = self.read_config(
|
||||||
|
"""
|
||||||
|
room_prejoin_state:
|
||||||
|
disable_default_event_types: true
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self.assertEqual(config.room_prejoin_state, StateFilter.none())
|
||||||
|
|
||||||
|
def test_event_without_state_key(self) -> None:
|
||||||
|
config = self.read_config(
|
||||||
|
"""
|
||||||
|
room_prejoin_state:
|
||||||
|
disable_default_event_types: true
|
||||||
|
additional_event_types:
|
||||||
|
- foo
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
|
||||||
|
self.assertEqual(config.room_prejoin_state.concrete_types(), [])
|
||||||
|
|
||||||
|
def test_event_with_specific_state_key(self) -> None:
|
||||||
|
config = self.read_config(
|
||||||
|
"""
|
||||||
|
room_prejoin_state:
|
||||||
|
disable_default_event_types: true
|
||||||
|
additional_event_types:
|
||||||
|
- [foo, bar]
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self.assertFalse(config.room_prejoin_state.has_wildcards())
|
||||||
|
self.assertEqual(
|
||||||
|
set(config.room_prejoin_state.concrete_types()),
|
||||||
|
{("foo", "bar")},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_repeated_event_with_specific_state_key(self) -> None:
|
||||||
|
config = self.read_config(
|
||||||
|
"""
|
||||||
|
room_prejoin_state:
|
||||||
|
disable_default_event_types: true
|
||||||
|
additional_event_types:
|
||||||
|
- [foo, bar]
|
||||||
|
- [foo, baz]
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self.assertFalse(config.room_prejoin_state.has_wildcards())
|
||||||
|
self.assertEqual(
|
||||||
|
set(config.room_prejoin_state.concrete_types()),
|
||||||
|
{("foo", "bar"), ("foo", "baz")},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
|
||||||
|
config = self.read_config(
|
||||||
|
"""
|
||||||
|
room_prejoin_state:
|
||||||
|
disable_default_event_types: true
|
||||||
|
additional_event_types:
|
||||||
|
- [foo, bar]
|
||||||
|
- foo
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
|
||||||
|
self.assertEqual(config.room_prejoin_state.concrete_types(), [])
|
||||||
|
|
||||||
|
config = self.read_config(
|
||||||
|
"""
|
||||||
|
room_prejoin_state:
|
||||||
|
disable_default_event_types: true
|
||||||
|
additional_event_types:
|
||||||
|
- foo
|
||||||
|
- [foo, bar]
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
|
||||||
|
self.assertEqual(config.room_prejoin_state.concrete_types(), [])
|
||||||
|
|
||||||
|
def test_bad_event_type_entry_raises(self) -> None:
|
||||||
|
with self.assertRaises(ConfigError):
|
||||||
|
self.read_config(
|
||||||
|
"""
|
||||||
|
room_prejoin_state:
|
||||||
|
additional_event_types:
|
||||||
|
- []
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(ConfigError):
|
||||||
|
self.read_config(
|
||||||
|
"""
|
||||||
|
room_prejoin_state:
|
||||||
|
additional_event_types:
|
||||||
|
- [a]
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(ConfigError):
|
||||||
|
self.read_config(
|
||||||
|
"""
|
||||||
|
room_prejoin_state:
|
||||||
|
additional_event_types:
|
||||||
|
- [a, b, c]
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(ConfigError):
|
||||||
|
self.read_config(
|
||||||
|
"""
|
||||||
|
room_prejoin_state:
|
||||||
|
additional_event_types:
|
||||||
|
- [true, 1.23]
|
||||||
|
"""
|
||||||
|
)
|
|
@ -12,19 +12,20 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest as stdlib_unittest
|
||||||
|
|
||||||
from synapse.api.constants import EventContentFields
|
from synapse.api.constants import EventContentFields
|
||||||
from synapse.api.room_versions import RoomVersions
|
from synapse.api.room_versions import RoomVersions
|
||||||
from synapse.events import make_event_from_dict
|
from synapse.events import make_event_from_dict
|
||||||
from synapse.events.utils import (
|
from synapse.events.utils import (
|
||||||
SerializeEventConfig,
|
SerializeEventConfig,
|
||||||
copy_and_fixup_power_levels_contents,
|
copy_and_fixup_power_levels_contents,
|
||||||
|
maybe_upsert_event_field,
|
||||||
prune_event,
|
prune_event,
|
||||||
serialize_event,
|
serialize_event,
|
||||||
)
|
)
|
||||||
from synapse.util.frozenutils import freeze
|
from synapse.util.frozenutils import freeze
|
||||||
|
|
||||||
from tests import unittest
|
|
||||||
|
|
||||||
|
|
||||||
def MockEvent(**kwargs):
|
def MockEvent(**kwargs):
|
||||||
if "event_id" not in kwargs:
|
if "event_id" not in kwargs:
|
||||||
|
@ -34,7 +35,31 @@ def MockEvent(**kwargs):
|
||||||
return make_event_from_dict(kwargs)
|
return make_event_from_dict(kwargs)
|
||||||
|
|
||||||
|
|
||||||
class PruneEventTestCase(unittest.TestCase):
|
class TestMaybeUpsertEventField(stdlib_unittest.TestCase):
|
||||||
|
def test_update_okay(self) -> None:
|
||||||
|
event = make_event_from_dict({"event_id": "$1234"})
|
||||||
|
success = maybe_upsert_event_field(event, event.unsigned, "key", "value")
|
||||||
|
self.assertTrue(success)
|
||||||
|
self.assertEqual(event.unsigned["key"], "value")
|
||||||
|
|
||||||
|
def test_update_not_okay(self) -> None:
|
||||||
|
event = make_event_from_dict({"event_id": "$1234"})
|
||||||
|
LARGE_STRING = "a" * 100_000
|
||||||
|
success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
|
||||||
|
self.assertFalse(success)
|
||||||
|
self.assertNotIn("key", event.unsigned)
|
||||||
|
|
||||||
|
def test_update_not_okay_leaves_original_value(self) -> None:
|
||||||
|
event = make_event_from_dict(
|
||||||
|
{"event_id": "$1234", "unsigned": {"key": "value"}}
|
||||||
|
)
|
||||||
|
LARGE_STRING = "a" * 100_000
|
||||||
|
success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
|
||||||
|
self.assertFalse(success)
|
||||||
|
self.assertEqual(event.unsigned["key"], "value")
|
||||||
|
|
||||||
|
|
||||||
|
class PruneEventTestCase(stdlib_unittest.TestCase):
|
||||||
def run_test(self, evdict, matchdict, **kwargs):
|
def run_test(self, evdict, matchdict, **kwargs):
|
||||||
"""
|
"""
|
||||||
Asserts that a new event constructed with `evdict` will look like
|
Asserts that a new event constructed with `evdict` will look like
|
||||||
|
@ -391,7 +416,7 @@ class PruneEventTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SerializeEventTestCase(unittest.TestCase):
|
class SerializeEventTestCase(stdlib_unittest.TestCase):
|
||||||
def serialize(self, ev, fields):
|
def serialize(self, ev, fields):
|
||||||
return serialize_event(
|
return serialize_event(
|
||||||
ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
|
ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
|
||||||
|
@ -513,7 +538,7 @@ class SerializeEventTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CopyPowerLevelsContentTestCase(unittest.TestCase):
|
class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.test_content = {
|
self.test_content = {
|
||||||
"ban": 50,
|
"ban": 50,
|
||||||
|
|
|
@ -26,7 +26,7 @@ from synapse.types import JsonDict, RoomID, StateMap, UserID
|
||||||
from synapse.types.state import StateFilter
|
from synapse.types.state import StateFilter
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase, TestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -494,624 +494,3 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(is_all, True)
|
self.assertEqual(is_all, True)
|
||||||
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
|
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
|
||||||
|
|
||||||
|
|
||||||
class StateFilterDifferenceTestCase(TestCase):
|
|
||||||
def assert_difference(
|
|
||||||
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
|
|
||||||
) -> None:
|
|
||||||
self.assertEqual(
|
|
||||||
minuend.approx_difference(subtrahend),
|
|
||||||
expected,
|
|
||||||
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_state_filter_difference_no_include_other_minus_no_include_other(
|
|
||||||
self,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Tests the StateFilter.approx_difference method
|
|
||||||
where, in a.approx_difference(b), both a and b do not have the
|
|
||||||
include_others flag set.
|
|
||||||
"""
|
|
||||||
# (wildcard on state keys) - (wildcard on state keys):
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None, EventTypes.Create: None},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze({EventTypes.Create: None}, include_others=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (wildcard on state keys) - (specific state keys)
|
|
||||||
# This one is an over-approximation because we can't represent
|
|
||||||
# 'all state keys except a few named examples'
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: {"@wombat:spqr"}},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (wildcard on state keys) - (no state keys)
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: set(),
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (specific state keys) - (wildcard on state keys):
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.CanonicalAlias: {""}},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (specific state keys) - (specific state keys)
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr"},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@spqr:spqr"},
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (specific state keys) - (no state keys)
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: set(),
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
|
|
||||||
"""
|
|
||||||
Tests the StateFilter.approx_difference method
|
|
||||||
where, in a.approx_difference(b), only a has the include_others flag set.
|
|
||||||
"""
|
|
||||||
# (wildcard on state keys) - (wildcard on state keys):
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None, EventTypes.Create: None},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Create: None,
|
|
||||||
EventTypes.Member: set(),
|
|
||||||
EventTypes.CanonicalAlias: set(),
|
|
||||||
},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (wildcard on state keys) - (specific state keys)
|
|
||||||
# This one is an over-approximation because we can't represent
|
|
||||||
# 'all state keys except a few named examples'
|
|
||||||
# This also shows that the resultant state filter is normalised.
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr"},
|
|
||||||
EventTypes.Create: {""},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter(types=frozendict(), include_others=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (wildcard on state keys) - (no state keys)
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: set(),
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter(
|
|
||||||
types=frozendict(),
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (specific state keys) - (wildcard on state keys):
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
EventTypes.Member: set(),
|
|
||||||
},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (specific state keys) - (specific state keys)
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr"},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@spqr:spqr"},
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (specific state keys) - (no state keys)
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: set(),
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_state_filter_difference_include_other_minus_include_other(self) -> None:
|
|
||||||
"""
|
|
||||||
Tests the StateFilter.approx_difference method
|
|
||||||
where, in a.approx_difference(b), both a and b have the include_others
|
|
||||||
flag set.
|
|
||||||
"""
|
|
||||||
# (wildcard on state keys) - (wildcard on state keys):
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None, EventTypes.Create: None},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter(types=frozendict(), include_others=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (wildcard on state keys) - (specific state keys)
|
|
||||||
# This one is an over-approximation because we can't represent
|
|
||||||
# 'all state keys except a few named examples'
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr"},
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (wildcard on state keys) - (no state keys)
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: set(),
|
|
||||||
},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (specific state keys) - (wildcard on state keys):
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter(
|
|
||||||
types=frozendict(),
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (specific state keys) - (specific state keys)
|
|
||||||
# This one is an over-approximation because we can't represent
|
|
||||||
# 'all state keys except a few named examples'
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
EventTypes.Create: {""},
|
|
||||||
},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr"},
|
|
||||||
EventTypes.Create: set(),
|
|
||||||
},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@spqr:spqr"},
|
|
||||||
EventTypes.Create: {""},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (specific state keys) - (no state keys)
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: set(),
|
|
||||||
},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
|
|
||||||
"""
|
|
||||||
Tests the StateFilter.approx_difference method
|
|
||||||
where, in a.approx_difference(b), only b has the include_others flag set.
|
|
||||||
"""
|
|
||||||
# (wildcard on state keys) - (wildcard on state keys):
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None, EventTypes.Create: None},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter(types=frozendict(), include_others=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (wildcard on state keys) - (specific state keys)
|
|
||||||
# This one is an over-approximation because we can't represent
|
|
||||||
# 'all state keys except a few named examples'
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: {"@wombat:spqr"}},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (wildcard on state keys) - (no state keys)
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: set(),
|
|
||||||
},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (specific state keys) - (wildcard on state keys):
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter(
|
|
||||||
types=frozendict(),
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (specific state keys) - (specific state keys)
|
|
||||||
# This one is an over-approximation because we can't represent
|
|
||||||
# 'all state keys except a few named examples'
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr"},
|
|
||||||
},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@spqr:spqr"},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# (specific state keys) - (no state keys)
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
|
||||||
EventTypes.CanonicalAlias: {""},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: set(),
|
|
||||||
},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_state_filter_difference_simple_cases(self) -> None:
|
|
||||||
"""
|
|
||||||
Tests some very simple cases of the StateFilter approx_difference,
|
|
||||||
that are not explicitly tested by the more in-depth tests.
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
|
|
||||||
|
|
||||||
self.assert_difference(
|
|
||||||
StateFilter.all(),
|
|
||||||
StateFilter.none(),
|
|
||||||
StateFilter.all(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StateFilterTestCase(TestCase):
|
|
||||||
def test_return_expanded(self) -> None:
|
|
||||||
"""
|
|
||||||
Tests the behaviour of the return_expanded() function that expands
|
|
||||||
StateFilters to include more state types (for the sake of cache hit rate).
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
|
|
||||||
|
|
||||||
self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
|
|
||||||
|
|
||||||
# Concrete-only state filters stay the same
|
|
||||||
# (Case: mixed filter)
|
|
||||||
self.assertEqual(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
|
||||||
"some.other.state.type": {""},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
).return_expanded(),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
|
||||||
"some.other.state.type": {""},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Concrete-only state filters stay the same
|
|
||||||
# (Case: non-member-only filter)
|
|
||||||
self.assertEqual(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{"some.other.state.type": {""}}, include_others=False
|
|
||||||
).return_expanded(),
|
|
||||||
StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Concrete-only state filters stay the same
|
|
||||||
# (Case: member-only filter)
|
|
||||||
self.assertEqual(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
).return_expanded(),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Wildcard member-only state filters stay the same
|
|
||||||
self.assertEqual(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None},
|
|
||||||
include_others=False,
|
|
||||||
).return_expanded(),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: None},
|
|
||||||
include_others=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# If there is a wildcard in the non-member portion of the filter,
|
|
||||||
# it's expanded to include ALL non-member events.
|
|
||||||
# (Case: mixed filter)
|
|
||||||
self.assertEqual(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
|
||||||
"some.other.state.type": None,
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
).return_expanded(),
|
|
||||||
StateFilter.freeze(
|
|
||||||
{EventTypes.Member: {"@wombat:test", "@alicia:test"}},
|
|
||||||
include_others=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# If there is a wildcard in the non-member portion of the filter,
|
|
||||||
# it's expanded to include ALL non-member events.
|
|
||||||
# (Case: non-member-only filter)
|
|
||||||
self.assertEqual(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
"some.other.state.type": None,
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
).return_expanded(),
|
|
||||||
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
StateFilter.freeze(
|
|
||||||
{
|
|
||||||
"some.other.state.type": None,
|
|
||||||
"yet.another.state.type": {"wombat"},
|
|
||||||
},
|
|
||||||
include_others=False,
|
|
||||||
).return_expanded(),
|
|
||||||
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
|
|
||||||
)
|
|
||||||
|
|
0
tests/types/__init__.py
Normal file
0
tests/types/__init__.py
Normal file
627
tests/types/test_state.py
Normal file
627
tests/types/test_state.py
Normal file
|
@ -0,0 +1,627 @@
|
||||||
|
from frozendict import frozendict
|
||||||
|
|
||||||
|
from synapse.api.constants import EventTypes
|
||||||
|
from synapse.types.state import StateFilter
|
||||||
|
|
||||||
|
from tests.unittest import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class StateFilterDifferenceTestCase(TestCase):
|
||||||
|
def assert_difference(
|
||||||
|
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
|
||||||
|
) -> None:
|
||||||
|
self.assertEqual(
|
||||||
|
minuend.approx_difference(subtrahend),
|
||||||
|
expected,
|
||||||
|
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_state_filter_difference_no_include_other_minus_no_include_other(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Tests the StateFilter.approx_difference method
|
||||||
|
where, in a.approx_difference(b), both a and b do not have the
|
||||||
|
include_others flag set.
|
||||||
|
"""
|
||||||
|
# (wildcard on state keys) - (wildcard on state keys):
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None, EventTypes.Create: None},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze({EventTypes.Create: None}, include_others=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (wildcard on state keys) - (specific state keys)
|
||||||
|
# This one is an over-approximation because we can't represent
|
||||||
|
# 'all state keys except a few named examples'
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: {"@wombat:spqr"}},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (wildcard on state keys) - (no state keys)
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: set(),
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (specific state keys) - (wildcard on state keys):
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.CanonicalAlias: {""}},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (specific state keys) - (specific state keys)
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr"},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@spqr:spqr"},
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (specific state keys) - (no state keys)
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: set(),
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
|
||||||
|
"""
|
||||||
|
Tests the StateFilter.approx_difference method
|
||||||
|
where, in a.approx_difference(b), only a has the include_others flag set.
|
||||||
|
"""
|
||||||
|
# (wildcard on state keys) - (wildcard on state keys):
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None, EventTypes.Create: None},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Create: None,
|
||||||
|
EventTypes.Member: set(),
|
||||||
|
EventTypes.CanonicalAlias: set(),
|
||||||
|
},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (wildcard on state keys) - (specific state keys)
|
||||||
|
# This one is an over-approximation because we can't represent
|
||||||
|
# 'all state keys except a few named examples'
|
||||||
|
# This also shows that the resultant state filter is normalised.
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr"},
|
||||||
|
EventTypes.Create: {""},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter(types=frozendict(), include_others=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (wildcard on state keys) - (no state keys)
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: set(),
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter(
|
||||||
|
types=frozendict(),
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (specific state keys) - (wildcard on state keys):
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
EventTypes.Member: set(),
|
||||||
|
},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (specific state keys) - (specific state keys)
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr"},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@spqr:spqr"},
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (specific state keys) - (no state keys)
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: set(),
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_state_filter_difference_include_other_minus_include_other(self) -> None:
|
||||||
|
"""
|
||||||
|
Tests the StateFilter.approx_difference method
|
||||||
|
where, in a.approx_difference(b), both a and b have the include_others
|
||||||
|
flag set.
|
||||||
|
"""
|
||||||
|
# (wildcard on state keys) - (wildcard on state keys):
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None, EventTypes.Create: None},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter(types=frozendict(), include_others=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (wildcard on state keys) - (specific state keys)
|
||||||
|
# This one is an over-approximation because we can't represent
|
||||||
|
# 'all state keys except a few named examples'
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr"},
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (wildcard on state keys) - (no state keys)
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: set(),
|
||||||
|
},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (specific state keys) - (wildcard on state keys):
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter(
|
||||||
|
types=frozendict(),
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (specific state keys) - (specific state keys)
|
||||||
|
# This one is an over-approximation because we can't represent
|
||||||
|
# 'all state keys except a few named examples'
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
EventTypes.Create: {""},
|
||||||
|
},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr"},
|
||||||
|
EventTypes.Create: set(),
|
||||||
|
},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@spqr:spqr"},
|
||||||
|
EventTypes.Create: {""},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (specific state keys) - (no state keys)
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: set(),
|
||||||
|
},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
|
||||||
|
"""
|
||||||
|
Tests the StateFilter.approx_difference method
|
||||||
|
where, in a.approx_difference(b), only b has the include_others flag set.
|
||||||
|
"""
|
||||||
|
# (wildcard on state keys) - (wildcard on state keys):
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None, EventTypes.Create: None},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter(types=frozendict(), include_others=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (wildcard on state keys) - (specific state keys)
|
||||||
|
# This one is an over-approximation because we can't represent
|
||||||
|
# 'all state keys except a few named examples'
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: {"@wombat:spqr"}},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (wildcard on state keys) - (no state keys)
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: set(),
|
||||||
|
},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (specific state keys) - (wildcard on state keys):
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter(
|
||||||
|
types=frozendict(),
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (specific state keys) - (specific state keys)
|
||||||
|
# This one is an over-approximation because we can't represent
|
||||||
|
# 'all state keys except a few named examples'
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr"},
|
||||||
|
},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@spqr:spqr"},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (specific state keys) - (no state keys)
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||||
|
EventTypes.CanonicalAlias: {""},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: set(),
|
||||||
|
},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_state_filter_difference_simple_cases(self) -> None:
|
||||||
|
"""
|
||||||
|
Tests some very simple cases of the StateFilter approx_difference,
|
||||||
|
that are not explicitly tested by the more in-depth tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
|
||||||
|
|
||||||
|
self.assert_difference(
|
||||||
|
StateFilter.all(),
|
||||||
|
StateFilter.none(),
|
||||||
|
StateFilter.all(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StateFilterTestCase(TestCase):
|
||||||
|
def test_return_expanded(self) -> None:
|
||||||
|
"""
|
||||||
|
Tests the behaviour of the return_expanded() function that expands
|
||||||
|
StateFilters to include more state types (for the sake of cache hit rate).
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
|
||||||
|
|
||||||
|
self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
|
||||||
|
|
||||||
|
# Concrete-only state filters stay the same
|
||||||
|
# (Case: mixed filter)
|
||||||
|
self.assertEqual(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
||||||
|
"some.other.state.type": {""},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
).return_expanded(),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
||||||
|
"some.other.state.type": {""},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Concrete-only state filters stay the same
|
||||||
|
# (Case: non-member-only filter)
|
||||||
|
self.assertEqual(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{"some.other.state.type": {""}}, include_others=False
|
||||||
|
).return_expanded(),
|
||||||
|
StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Concrete-only state filters stay the same
|
||||||
|
# (Case: member-only filter)
|
||||||
|
self.assertEqual(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
).return_expanded(),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wildcard member-only state filters stay the same
|
||||||
|
self.assertEqual(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None},
|
||||||
|
include_others=False,
|
||||||
|
).return_expanded(),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: None},
|
||||||
|
include_others=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# If there is a wildcard in the non-member portion of the filter,
|
||||||
|
# it's expanded to include ALL non-member events.
|
||||||
|
# (Case: mixed filter)
|
||||||
|
self.assertEqual(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
||||||
|
"some.other.state.type": None,
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
).return_expanded(),
|
||||||
|
StateFilter.freeze(
|
||||||
|
{EventTypes.Member: {"@wombat:test", "@alicia:test"}},
|
||||||
|
include_others=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# If there is a wildcard in the non-member portion of the filter,
|
||||||
|
# it's expanded to include ALL non-member events.
|
||||||
|
# (Case: non-member-only filter)
|
||||||
|
self.assertEqual(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
"some.other.state.type": None,
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
).return_expanded(),
|
||||||
|
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
StateFilter.freeze(
|
||||||
|
{
|
||||||
|
"some.other.state.type": None,
|
||||||
|
"yet.another.state.type": {"wombat"},
|
||||||
|
},
|
||||||
|
include_others=False,
|
||||||
|
).return_expanded(),
|
||||||
|
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
|
||||||
|
)
|
Loading…
Reference in a new issue