Allow spam-checker modules to provide async methods. (#8890)

Spam checker modules can now provide async methods. This is implemented
in a backwards-compatible manner.
This commit is contained in:
David Teller 2020-12-11 20:05:15 +01:00 committed by GitHub
parent 5d34f40d49
commit f14428b25c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 98 additions and 73 deletions

1
changelog.d/8890.feature Normal file
View file

@ -0,0 +1 @@
Spam-checkers may now define their methods as `async`.

View file

@ -22,6 +22,8 @@ well as some specific methods:
* `user_may_create_room` * `user_may_create_room`
* `user_may_create_room_alias` * `user_may_create_room_alias`
* `user_may_publish_room` * `user_may_publish_room`
* `check_username_for_spam`
* `check_registration_for_spam`
The details of each of these methods (as well as their inputs and outputs) The details of each of these methods (as well as their inputs and outputs)
are documented in the `synapse.events.spamcheck.SpamChecker` class. are documented in the `synapse.events.spamcheck.SpamChecker` class.
@ -32,28 +34,33 @@ call back into the homeserver internals.
### Example ### Example
```python ```python
from synapse.spam_checker_api import RegistrationBehaviour
class ExampleSpamChecker: class ExampleSpamChecker:
def __init__(self, config, api): def __init__(self, config, api):
self.config = config self.config = config
self.api = api self.api = api
def check_event_for_spam(self, foo): async def check_event_for_spam(self, foo):
return False # allow all events return False # allow all events
def user_may_invite(self, inviter_userid, invitee_userid, room_id): async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
return True # allow all invites return True # allow all invites
def user_may_create_room(self, userid): async def user_may_create_room(self, userid):
return True # allow all room creations return True # allow all room creations
def user_may_create_room_alias(self, userid, room_alias): async def user_may_create_room_alias(self, userid, room_alias):
return True # allow all room aliases return True # allow all room aliases
def user_may_publish_room(self, userid, room_id): async def user_may_publish_room(self, userid, room_id):
return True # allow publishing of all rooms return True # allow publishing of all rooms
def check_username_for_spam(self, user_profile): async def check_username_for_spam(self, user_profile):
return False # allow all usernames return False # allow all usernames
async def check_registration_for_spam(self, email_threepid, username, request_info):
return RegistrationBehaviour.ALLOW # allow all registrations
``` ```
## Configuration ## Configuration

View file

@ -15,10 +15,11 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from synapse.spam_checker_api import RegistrationBehaviour from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import Collection from synapse.types import Collection
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING: if TYPE_CHECKING:
import synapse.events import synapse.events
@ -39,7 +40,9 @@ class SpamChecker:
else: else:
self.spam_checkers.append(module(config=config)) self.spam_checkers.append(module(config=config))
def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool: async def check_event_for_spam(
self, event: "synapse.events.EventBase"
) -> Union[bool, str]:
"""Checks if a given event is considered "spammy" by this server. """Checks if a given event is considered "spammy" by this server.
If the server considers an event spammy, then it will be rejected if If the server considers an event spammy, then it will be rejected if
@ -50,15 +53,16 @@ class SpamChecker:
event: the event to be checked event: the event to be checked
Returns: Returns:
True if the event is spammy. True or a string if the event is spammy. If a string is returned it
will be used as the error message returned to the user.
""" """
for spam_checker in self.spam_checkers: for spam_checker in self.spam_checkers:
if spam_checker.check_event_for_spam(event): if await maybe_awaitable(spam_checker.check_event_for_spam(event)):
return True return True
return False return False
def user_may_invite( async def user_may_invite(
self, inviter_userid: str, invitee_userid: str, room_id: str self, inviter_userid: str, invitee_userid: str, room_id: str
) -> bool: ) -> bool:
"""Checks if a given user may send an invite """Checks if a given user may send an invite
@ -75,14 +79,18 @@ class SpamChecker:
""" """
for spam_checker in self.spam_checkers: for spam_checker in self.spam_checkers:
if ( if (
spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id) await maybe_awaitable(
spam_checker.user_may_invite(
inviter_userid, invitee_userid, room_id
)
)
is False is False
): ):
return False return False
return True return True
def user_may_create_room(self, userid: str) -> bool: async def user_may_create_room(self, userid: str) -> bool:
"""Checks if a given user may create a room """Checks if a given user may create a room
If this method returns false, the creation request will be rejected. If this method returns false, the creation request will be rejected.
@ -94,12 +102,15 @@ class SpamChecker:
True if the user may create a room, otherwise False True if the user may create a room, otherwise False
""" """
for spam_checker in self.spam_checkers: for spam_checker in self.spam_checkers:
if spam_checker.user_may_create_room(userid) is False: if (
await maybe_awaitable(spam_checker.user_may_create_room(userid))
is False
):
return False return False
return True return True
def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool: async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
"""Checks if a given user may create a room alias """Checks if a given user may create a room alias
If this method returns false, the association request will be rejected. If this method returns false, the association request will be rejected.
@ -112,12 +123,17 @@ class SpamChecker:
True if the user may create a room alias, otherwise False True if the user may create a room alias, otherwise False
""" """
for spam_checker in self.spam_checkers: for spam_checker in self.spam_checkers:
if spam_checker.user_may_create_room_alias(userid, room_alias) is False: if (
await maybe_awaitable(
spam_checker.user_may_create_room_alias(userid, room_alias)
)
is False
):
return False return False
return True return True
def user_may_publish_room(self, userid: str, room_id: str) -> bool: async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
"""Checks if a given user may publish a room to the directory """Checks if a given user may publish a room to the directory
If this method returns false, the publish request will be rejected. If this method returns false, the publish request will be rejected.
@ -130,12 +146,17 @@ class SpamChecker:
True if the user may publish the room, otherwise False True if the user may publish the room, otherwise False
""" """
for spam_checker in self.spam_checkers: for spam_checker in self.spam_checkers:
if spam_checker.user_may_publish_room(userid, room_id) is False: if (
await maybe_awaitable(
spam_checker.user_may_publish_room(userid, room_id)
)
is False
):
return False return False
return True return True
def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server. """Checks if a user ID or display name are considered "spammy" by this server.
If the server considers a username spammy, then it will not be included in If the server considers a username spammy, then it will not be included in
@ -157,12 +178,12 @@ class SpamChecker:
if checker: if checker:
# Make a copy of the user profile object to ensure the spam checker # Make a copy of the user profile object to ensure the spam checker
# cannot modify it. # cannot modify it.
if checker(user_profile.copy()): if await maybe_awaitable(checker(user_profile.copy())):
return True return True
return False return False
def check_registration_for_spam( async def check_registration_for_spam(
self, self,
email_threepid: Optional[dict], email_threepid: Optional[dict],
username: Optional[str], username: Optional[str],
@ -185,7 +206,9 @@ class SpamChecker:
# spam checker # spam checker
checker = getattr(spam_checker, "check_registration_for_spam", None) checker = getattr(spam_checker, "check_registration_for_spam", None)
if checker: if checker:
behaviour = checker(email_threepid, username, request_info) behaviour = await maybe_awaitable(
checker(email_threepid, username, request_info)
)
assert isinstance(behaviour, RegistrationBehaviour) assert isinstance(behaviour, RegistrationBehaviour)
if behaviour != RegistrationBehaviour.ALLOW: if behaviour != RegistrationBehaviour.ALLOW:
return behaviour return behaviour

View file

@ -78,6 +78,7 @@ class FederationBase:
ctx = current_context() ctx = current_context()
@defer.inlineCallbacks
def callback(_, pdu: EventBase): def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx): with PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu): if not check_event_content_hash(pdu):
@ -105,7 +106,11 @@ class FederationBase:
) )
return redacted_event return redacted_event
if self.spam_checker.check_event_for_spam(pdu): result = yield defer.ensureDeferred(
self.spam_checker.check_event_for_spam(pdu)
)
if result:
logger.warning( logger.warning(
"Event contains spam, redacting %s: %s", "Event contains spam, redacting %s: %s",
pdu.event_id, pdu.event_id,

View file

@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
import logging import logging
import time import time
import unicodedata import unicodedata
@ -59,6 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
from synapse.types import JsonDict, Requester, UserID from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email from synapse.util.threepids import canonicalise_email
@ -1639,6 +1639,6 @@ class PasswordProvider:
# This might return an awaitable, if it does block the log out # This might return an awaitable, if it does block the log out
# until it completes. # until it completes.
result = g(user_id=user_id, device_id=device_id, access_token=access_token,) await maybe_awaitable(
if inspect.isawaitable(result): g(user_id=user_id, device_id=device_id, access_token=access_token,)
await result )

View file

@ -133,7 +133,9 @@ class DirectoryHandler(BaseHandler):
403, "You must be in the room to create an alias for it" 403, "You must be in the room to create an alias for it"
) )
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): if not await self.spam_checker.user_may_create_room_alias(
user_id, room_alias
):
raise AuthError(403, "This user is not permitted to create this alias") raise AuthError(403, "This user is not permitted to create this alias")
if not self.config.is_alias_creation_allowed( if not self.config.is_alias_creation_allowed(
@ -409,7 +411,7 @@ class DirectoryHandler(BaseHandler):
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
if not self.spam_checker.user_may_publish_room(user_id, room_id): if not await self.spam_checker.user_may_publish_room(user_id, room_id):
raise AuthError( raise AuthError(
403, "This user is not permitted to publish rooms to the room list" 403, "This user is not permitted to publish rooms to the room list"
) )

View file

@ -1593,7 +1593,7 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites: if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites") raise SynapseError(403, "This server does not accept room invites")
if not self.spam_checker.user_may_invite( if not await self.spam_checker.user_may_invite(
event.sender, event.state_key, event.room_id event.sender, event.state_key, event.room_id
): ):
raise SynapseError( raise SynapseError(

View file

@ -744,7 +744,7 @@ class EventCreationHandler:
event.sender, event.sender,
) )
spam_error = self.spam_checker.check_event_for_spam(event) spam_error = await self.spam_checker.check_event_for_spam(event)
if spam_error: if spam_error:
if not isinstance(spam_error, str): if not isinstance(spam_error, str):
spam_error = "Spam is not permitted here" spam_error = "Spam is not permitted here"

View file

@ -18,7 +18,6 @@ from typing import List, Tuple
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler from synapse.handlers._base import BaseHandler
from synapse.types import JsonDict, ReadReceipt, get_domain_from_id from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -98,10 +97,8 @@ class ReceiptsHandler(BaseHandler):
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate. # Note that the min here shouldn't be relied upon to be accurate.
await maybe_awaitable( await self.hs.get_pusherpool().on_new_receipts(
self.hs.get_pusherpool().on_new_receipts( min_batch_id, max_batch_id, affected_room_ids
min_batch_id, max_batch_id, affected_room_ids
)
) )
return True return True

View file

@ -187,7 +187,7 @@ class RegistrationHandler(BaseHandler):
""" """
self.check_registration_ratelimit(address) self.check_registration_ratelimit(address)
result = self.spam_checker.check_registration_for_spam( result = await self.spam_checker.check_registration_for_spam(
threepid, localpart, user_agent_ips or [], threepid, localpart, user_agent_ips or [],
) )

View file

@ -358,7 +358,7 @@ class RoomCreationHandler(BaseHandler):
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
if not self.spam_checker.user_may_create_room(user_id): if not await self.spam_checker.user_may_create_room(user_id):
raise SynapseError(403, "You are not permitted to create rooms") raise SynapseError(403, "You are not permitted to create rooms")
creation_content = { creation_content = {
@ -609,7 +609,7 @@ class RoomCreationHandler(BaseHandler):
403, "You are not permitted to create rooms", Codes.FORBIDDEN 403, "You are not permitted to create rooms", Codes.FORBIDDEN
) )
if not is_requester_admin and not self.spam_checker.user_may_create_room( if not is_requester_admin and not await self.spam_checker.user_may_create_room(
user_id user_id
): ):
raise SynapseError(403, "You are not permitted to create rooms") raise SynapseError(403, "You are not permitted to create rooms")

View file

@ -408,7 +408,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
) )
block_invite = True block_invite = True
if not self.spam_checker.user_may_invite( if not await self.spam_checker.user_may_invite(
requester.user.to_string(), target.to_string(), room_id requester.user.to_string(), target.to_string(), room_id
): ):
logger.info("Blocking invite due to spam checker") logger.info("Blocking invite due to spam checker")

View file

@ -81,11 +81,11 @@ class UserDirectoryHandler(StateDeltasHandler):
results = await self.store.search_user_dir(user_id, search_term, limit) results = await self.store.search_user_dir(user_id, search_term, limit)
# Remove any spammy users from the results. # Remove any spammy users from the results.
results["results"] = [ non_spammy_users = []
user for user in results["results"]:
for user in results["results"] if not await self.spam_checker.check_username_for_spam(user):
if not self.spam_checker.check_username_for_spam(user) non_spammy_users.append(user)
] results["results"] = non_spammy_users
return results return results

View file

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
import logging import logging
import threading import threading
from functools import wraps from functools import wraps
@ -25,6 +24,7 @@ from twisted.internet import defer
from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.logging.opentracing import noop_context_manager, start_active_span from synapse.logging.opentracing import noop_context_manager, start_active_span
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING: if TYPE_CHECKING:
import resource import resource
@ -206,12 +206,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
if bg_start_span: if bg_start_span:
ctx = start_active_span(desc, tags={"request_id": context.request}) ctx = start_active_span(desc, tags={"request_id": context.request})
with ctx: with ctx:
result = func(*args, **kwargs) return await maybe_awaitable(func(*args, **kwargs))
if inspect.isawaitable(result):
result = await result
return result
except Exception: except Exception:
logger.exception( logger.exception(
"Background process '%s' threw an exception", desc, "Background process '%s' threw an exception", desc,

View file

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
import logging import logging
import os import os
import shutil import shutil
@ -21,6 +20,7 @@ from typing import Optional
from synapse.config._base import Config from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background from synapse.logging.context import defer_to_thread, run_in_background
from synapse.util.async_helpers import maybe_awaitable
from ._base import FileInfo, Responder from ._base import FileInfo, Responder
from .media_storage import FileResponder from .media_storage import FileResponder
@ -91,16 +91,14 @@ class StorageProviderWrapper(StorageProvider):
if self.store_synchronous: if self.store_synchronous:
# store_file is supposed to return an Awaitable, but guard # store_file is supposed to return an Awaitable, but guard
# against improper implementations. # against improper implementations.
result = self.backend.store_file(path, file_info) return await maybe_awaitable(self.backend.store_file(path, file_info))
if inspect.isawaitable(result):
return await result
else: else:
# TODO: Handle errors. # TODO: Handle errors.
async def store(): async def store():
try: try:
result = self.backend.store_file(path, file_info) return await maybe_awaitable(
if inspect.isawaitable(result): self.backend.store_file(path, file_info)
return await result )
except Exception: except Exception:
logger.exception("Error storing file") logger.exception("Error storing file")
@ -110,9 +108,7 @@ class StorageProviderWrapper(StorageProvider):
async def fetch(self, path, file_info): async def fetch(self, path, file_info):
# fetch is supposed to return an Awaitable, but guard # fetch is supposed to return an Awaitable, but guard
# against improper implementations. # against improper implementations.
result = self.backend.fetch(path, file_info) return await maybe_awaitable(self.backend.fetch(path, file_info))
if inspect.isawaitable(result):
return await result
class FileStorageProviderBackend(StorageProvider): class FileStorageProviderBackend(StorageProvider):

View file

@ -618,7 +618,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return StatsHandler(self) return StatsHandler(self)
@cache_in_self @cache_in_self
def get_spam_checker(self): def get_spam_checker(self) -> SpamChecker:
return SpamChecker(self) return SpamChecker(self)
@cache_in_self @cache_in_self

View file

@ -15,10 +15,12 @@
# limitations under the License. # limitations under the License.
import collections import collections
import inspect
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager
from typing import ( from typing import (
Any, Any,
Awaitable,
Callable, Callable,
Dict, Dict,
Hashable, Hashable,
@ -542,11 +544,11 @@ class DoneAwaitable:
raise StopIteration(self.value) raise StopIteration(self.value)
def maybe_awaitable(value): def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
"""Convert a value to an awaitable if not already an awaitable. """Convert a value to an awaitable if not already an awaitable.
""" """
if inspect.isawaitable(value):
if hasattr(value, "__await__"): assert isinstance(value, Awaitable)
return value return value
return DoneAwaitable(value) return DoneAwaitable(value)

View file

@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
import logging import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -105,10 +105,7 @@ class Signal:
async def do(observer): async def do(observer):
try: try:
result = observer(*args, **kwargs) return await maybe_awaitable(observer(*args, **kwargs))
if inspect.isawaitable(result):
result = await result
return result
except Exception as e: except Exception as e:
logger.warning( logger.warning(
"%s signal observer %s failed: %r", self.name, observer, e, "%s signal observer %s failed: %r", self.name, observer, e,

View file

@ -270,7 +270,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
spam_checker = self.hs.get_spam_checker() spam_checker = self.hs.get_spam_checker()
class AllowAll: class AllowAll:
def check_username_for_spam(self, user_profile): async def check_username_for_spam(self, user_profile):
# Allow all users. # Allow all users.
return False return False
@ -283,7 +283,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that filters all users. # Configure a spam checker that filters all users.
class BlockAll: class BlockAll:
def check_username_for_spam(self, user_profile): async def check_username_for_spam(self, user_profile):
# All users are spammy. # All users are spammy.
return True return True