mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-18 08:54:54 +03:00
Allow spam-checker modules to be provide async methods. (#8890)
Spam checker modules can now provide async methods. This is implemented in a backwards-compatible manner.
This commit is contained in:
parent
5d34f40d49
commit
f14428b25c
19 changed files with 98 additions and 73 deletions
1
changelog.d/8890.feature
Normal file
1
changelog.d/8890.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Spam-checkers may now define their methods as `async`.
|
|
@ -22,6 +22,8 @@ well as some specific methods:
|
||||||
* `user_may_create_room`
|
* `user_may_create_room`
|
||||||
* `user_may_create_room_alias`
|
* `user_may_create_room_alias`
|
||||||
* `user_may_publish_room`
|
* `user_may_publish_room`
|
||||||
|
* `check_username_for_spam`
|
||||||
|
* `check_registration_for_spam`
|
||||||
|
|
||||||
The details of the each of these methods (as well as their inputs and outputs)
|
The details of the each of these methods (as well as their inputs and outputs)
|
||||||
are documented in the `synapse.events.spamcheck.SpamChecker` class.
|
are documented in the `synapse.events.spamcheck.SpamChecker` class.
|
||||||
|
@ -32,28 +34,33 @@ call back into the homeserver internals.
|
||||||
### Example
|
### Example
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
from synapse.spam_checker_api import RegistrationBehaviour
|
||||||
|
|
||||||
class ExampleSpamChecker:
|
class ExampleSpamChecker:
|
||||||
def __init__(self, config, api):
|
def __init__(self, config, api):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.api = api
|
self.api = api
|
||||||
|
|
||||||
def check_event_for_spam(self, foo):
|
async def check_event_for_spam(self, foo):
|
||||||
return False # allow all events
|
return False # allow all events
|
||||||
|
|
||||||
def user_may_invite(self, inviter_userid, invitee_userid, room_id):
|
async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
|
||||||
return True # allow all invites
|
return True # allow all invites
|
||||||
|
|
||||||
def user_may_create_room(self, userid):
|
async def user_may_create_room(self, userid):
|
||||||
return True # allow all room creations
|
return True # allow all room creations
|
||||||
|
|
||||||
def user_may_create_room_alias(self, userid, room_alias):
|
async def user_may_create_room_alias(self, userid, room_alias):
|
||||||
return True # allow all room aliases
|
return True # allow all room aliases
|
||||||
|
|
||||||
def user_may_publish_room(self, userid, room_id):
|
async def user_may_publish_room(self, userid, room_id):
|
||||||
return True # allow publishing of all rooms
|
return True # allow publishing of all rooms
|
||||||
|
|
||||||
def check_username_for_spam(self, user_profile):
|
async def check_username_for_spam(self, user_profile):
|
||||||
return False # allow all usernames
|
return False # allow all usernames
|
||||||
|
|
||||||
|
async def check_registration_for_spam(self, email_threepid, username, request_info):
|
||||||
|
return RegistrationBehaviour.ALLOW # allow all registrations
|
||||||
```
|
```
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
|
@ -15,10 +15,11 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from synapse.spam_checker_api import RegistrationBehaviour
|
from synapse.spam_checker_api import RegistrationBehaviour
|
||||||
from synapse.types import Collection
|
from synapse.types import Collection
|
||||||
|
from synapse.util.async_helpers import maybe_awaitable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import synapse.events
|
import synapse.events
|
||||||
|
@ -39,7 +40,9 @@ class SpamChecker:
|
||||||
else:
|
else:
|
||||||
self.spam_checkers.append(module(config=config))
|
self.spam_checkers.append(module(config=config))
|
||||||
|
|
||||||
def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool:
|
async def check_event_for_spam(
|
||||||
|
self, event: "synapse.events.EventBase"
|
||||||
|
) -> Union[bool, str]:
|
||||||
"""Checks if a given event is considered "spammy" by this server.
|
"""Checks if a given event is considered "spammy" by this server.
|
||||||
|
|
||||||
If the server considers an event spammy, then it will be rejected if
|
If the server considers an event spammy, then it will be rejected if
|
||||||
|
@ -50,15 +53,16 @@ class SpamChecker:
|
||||||
event: the event to be checked
|
event: the event to be checked
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if the event is spammy.
|
True or a string if the event is spammy. If a string is returned it
|
||||||
|
will be used as the error message returned to the user.
|
||||||
"""
|
"""
|
||||||
for spam_checker in self.spam_checkers:
|
for spam_checker in self.spam_checkers:
|
||||||
if spam_checker.check_event_for_spam(event):
|
if await maybe_awaitable(spam_checker.check_event_for_spam(event)):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def user_may_invite(
|
async def user_may_invite(
|
||||||
self, inviter_userid: str, invitee_userid: str, room_id: str
|
self, inviter_userid: str, invitee_userid: str, room_id: str
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Checks if a given user may send an invite
|
"""Checks if a given user may send an invite
|
||||||
|
@ -75,14 +79,18 @@ class SpamChecker:
|
||||||
"""
|
"""
|
||||||
for spam_checker in self.spam_checkers:
|
for spam_checker in self.spam_checkers:
|
||||||
if (
|
if (
|
||||||
spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
|
await maybe_awaitable(
|
||||||
|
spam_checker.user_may_invite(
|
||||||
|
inviter_userid, invitee_userid, room_id
|
||||||
|
)
|
||||||
|
)
|
||||||
is False
|
is False
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def user_may_create_room(self, userid: str) -> bool:
|
async def user_may_create_room(self, userid: str) -> bool:
|
||||||
"""Checks if a given user may create a room
|
"""Checks if a given user may create a room
|
||||||
|
|
||||||
If this method returns false, the creation request will be rejected.
|
If this method returns false, the creation request will be rejected.
|
||||||
|
@ -94,12 +102,15 @@ class SpamChecker:
|
||||||
True if the user may create a room, otherwise False
|
True if the user may create a room, otherwise False
|
||||||
"""
|
"""
|
||||||
for spam_checker in self.spam_checkers:
|
for spam_checker in self.spam_checkers:
|
||||||
if spam_checker.user_may_create_room(userid) is False:
|
if (
|
||||||
|
await maybe_awaitable(spam_checker.user_may_create_room(userid))
|
||||||
|
is False
|
||||||
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
|
async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
|
||||||
"""Checks if a given user may create a room alias
|
"""Checks if a given user may create a room alias
|
||||||
|
|
||||||
If this method returns false, the association request will be rejected.
|
If this method returns false, the association request will be rejected.
|
||||||
|
@ -112,12 +123,17 @@ class SpamChecker:
|
||||||
True if the user may create a room alias, otherwise False
|
True if the user may create a room alias, otherwise False
|
||||||
"""
|
"""
|
||||||
for spam_checker in self.spam_checkers:
|
for spam_checker in self.spam_checkers:
|
||||||
if spam_checker.user_may_create_room_alias(userid, room_alias) is False:
|
if (
|
||||||
|
await maybe_awaitable(
|
||||||
|
spam_checker.user_may_create_room_alias(userid, room_alias)
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def user_may_publish_room(self, userid: str, room_id: str) -> bool:
|
async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
|
||||||
"""Checks if a given user may publish a room to the directory
|
"""Checks if a given user may publish a room to the directory
|
||||||
|
|
||||||
If this method returns false, the publish request will be rejected.
|
If this method returns false, the publish request will be rejected.
|
||||||
|
@ -130,12 +146,17 @@ class SpamChecker:
|
||||||
True if the user may publish the room, otherwise False
|
True if the user may publish the room, otherwise False
|
||||||
"""
|
"""
|
||||||
for spam_checker in self.spam_checkers:
|
for spam_checker in self.spam_checkers:
|
||||||
if spam_checker.user_may_publish_room(userid, room_id) is False:
|
if (
|
||||||
|
await maybe_awaitable(
|
||||||
|
spam_checker.user_may_publish_room(userid, room_id)
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
|
async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
|
||||||
"""Checks if a user ID or display name are considered "spammy" by this server.
|
"""Checks if a user ID or display name are considered "spammy" by this server.
|
||||||
|
|
||||||
If the server considers a username spammy, then it will not be included in
|
If the server considers a username spammy, then it will not be included in
|
||||||
|
@ -157,12 +178,12 @@ class SpamChecker:
|
||||||
if checker:
|
if checker:
|
||||||
# Make a copy of the user profile object to ensure the spam checker
|
# Make a copy of the user profile object to ensure the spam checker
|
||||||
# cannot modify it.
|
# cannot modify it.
|
||||||
if checker(user_profile.copy()):
|
if await maybe_awaitable(checker(user_profile.copy())):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def check_registration_for_spam(
|
async def check_registration_for_spam(
|
||||||
self,
|
self,
|
||||||
email_threepid: Optional[dict],
|
email_threepid: Optional[dict],
|
||||||
username: Optional[str],
|
username: Optional[str],
|
||||||
|
@ -185,7 +206,9 @@ class SpamChecker:
|
||||||
# spam checker
|
# spam checker
|
||||||
checker = getattr(spam_checker, "check_registration_for_spam", None)
|
checker = getattr(spam_checker, "check_registration_for_spam", None)
|
||||||
if checker:
|
if checker:
|
||||||
behaviour = checker(email_threepid, username, request_info)
|
behaviour = await maybe_awaitable(
|
||||||
|
checker(email_threepid, username, request_info)
|
||||||
|
)
|
||||||
assert isinstance(behaviour, RegistrationBehaviour)
|
assert isinstance(behaviour, RegistrationBehaviour)
|
||||||
if behaviour != RegistrationBehaviour.ALLOW:
|
if behaviour != RegistrationBehaviour.ALLOW:
|
||||||
return behaviour
|
return behaviour
|
||||||
|
|
|
@ -78,6 +78,7 @@ class FederationBase:
|
||||||
|
|
||||||
ctx = current_context()
|
ctx = current_context()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def callback(_, pdu: EventBase):
|
def callback(_, pdu: EventBase):
|
||||||
with PreserveLoggingContext(ctx):
|
with PreserveLoggingContext(ctx):
|
||||||
if not check_event_content_hash(pdu):
|
if not check_event_content_hash(pdu):
|
||||||
|
@ -105,7 +106,11 @@ class FederationBase:
|
||||||
)
|
)
|
||||||
return redacted_event
|
return redacted_event
|
||||||
|
|
||||||
if self.spam_checker.check_event_for_spam(pdu):
|
result = yield defer.ensureDeferred(
|
||||||
|
self.spam_checker.check_event_for_spam(pdu)
|
||||||
|
)
|
||||||
|
|
||||||
|
if result:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Event contains spam, redacting %s: %s",
|
"Event contains spam, redacting %s: %s",
|
||||||
pdu.event_id,
|
pdu.event_id,
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import inspect
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
@ -59,6 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.module_api import ModuleApi
|
from synapse.module_api import ModuleApi
|
||||||
from synapse.types import JsonDict, Requester, UserID
|
from synapse.types import JsonDict, Requester, UserID
|
||||||
from synapse.util import stringutils as stringutils
|
from synapse.util import stringutils as stringutils
|
||||||
|
from synapse.util.async_helpers import maybe_awaitable
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
from synapse.util.threepids import canonicalise_email
|
from synapse.util.threepids import canonicalise_email
|
||||||
|
|
||||||
|
@ -1639,6 +1639,6 @@ class PasswordProvider:
|
||||||
|
|
||||||
# This might return an awaitable, if it does block the log out
|
# This might return an awaitable, if it does block the log out
|
||||||
# until it completes.
|
# until it completes.
|
||||||
result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
|
await maybe_awaitable(
|
||||||
if inspect.isawaitable(result):
|
g(user_id=user_id, device_id=device_id, access_token=access_token,)
|
||||||
await result
|
)
|
||||||
|
|
|
@ -133,7 +133,9 @@ class DirectoryHandler(BaseHandler):
|
||||||
403, "You must be in the room to create an alias for it"
|
403, "You must be in the room to create an alias for it"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
|
if not await self.spam_checker.user_may_create_room_alias(
|
||||||
|
user_id, room_alias
|
||||||
|
):
|
||||||
raise AuthError(403, "This user is not permitted to create this alias")
|
raise AuthError(403, "This user is not permitted to create this alias")
|
||||||
|
|
||||||
if not self.config.is_alias_creation_allowed(
|
if not self.config.is_alias_creation_allowed(
|
||||||
|
@ -409,7 +411,7 @@ class DirectoryHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
if not self.spam_checker.user_may_publish_room(user_id, room_id):
|
if not await self.spam_checker.user_may_publish_room(user_id, room_id):
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403, "This user is not permitted to publish rooms to the room list"
|
403, "This user is not permitted to publish rooms to the room list"
|
||||||
)
|
)
|
||||||
|
|
|
@ -1593,7 +1593,7 @@ class FederationHandler(BaseHandler):
|
||||||
if self.hs.config.block_non_admin_invites:
|
if self.hs.config.block_non_admin_invites:
|
||||||
raise SynapseError(403, "This server does not accept room invites")
|
raise SynapseError(403, "This server does not accept room invites")
|
||||||
|
|
||||||
if not self.spam_checker.user_may_invite(
|
if not await self.spam_checker.user_may_invite(
|
||||||
event.sender, event.state_key, event.room_id
|
event.sender, event.state_key, event.room_id
|
||||||
):
|
):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
|
|
|
@ -744,7 +744,7 @@ class EventCreationHandler:
|
||||||
event.sender,
|
event.sender,
|
||||||
)
|
)
|
||||||
|
|
||||||
spam_error = self.spam_checker.check_event_for_spam(event)
|
spam_error = await self.spam_checker.check_event_for_spam(event)
|
||||||
if spam_error:
|
if spam_error:
|
||||||
if not isinstance(spam_error, str):
|
if not isinstance(spam_error, str):
|
||||||
spam_error = "Spam is not permitted here"
|
spam_error = "Spam is not permitted here"
|
||||||
|
|
|
@ -18,7 +18,6 @@ from typing import List, Tuple
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.handlers._base import BaseHandler
|
from synapse.handlers._base import BaseHandler
|
||||||
from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
|
from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
|
||||||
from synapse.util.async_helpers import maybe_awaitable
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -98,11 +97,9 @@ class ReceiptsHandler(BaseHandler):
|
||||||
|
|
||||||
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
|
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
|
||||||
# Note that the min here shouldn't be relied upon to be accurate.
|
# Note that the min here shouldn't be relied upon to be accurate.
|
||||||
await maybe_awaitable(
|
await self.hs.get_pusherpool().on_new_receipts(
|
||||||
self.hs.get_pusherpool().on_new_receipts(
|
|
||||||
min_batch_id, max_batch_id, affected_room_ids
|
min_batch_id, max_batch_id, affected_room_ids
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
|
@ -187,7 +187,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
self.check_registration_ratelimit(address)
|
self.check_registration_ratelimit(address)
|
||||||
|
|
||||||
result = self.spam_checker.check_registration_for_spam(
|
result = await self.spam_checker.check_registration_for_spam(
|
||||||
threepid, localpart, user_agent_ips or [],
|
threepid, localpart, user_agent_ips or [],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -358,7 +358,7 @@ class RoomCreationHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
if not self.spam_checker.user_may_create_room(user_id):
|
if not await self.spam_checker.user_may_create_room(user_id):
|
||||||
raise SynapseError(403, "You are not permitted to create rooms")
|
raise SynapseError(403, "You are not permitted to create rooms")
|
||||||
|
|
||||||
creation_content = {
|
creation_content = {
|
||||||
|
@ -609,7 +609,7 @@ class RoomCreationHandler(BaseHandler):
|
||||||
403, "You are not permitted to create rooms", Codes.FORBIDDEN
|
403, "You are not permitted to create rooms", Codes.FORBIDDEN
|
||||||
)
|
)
|
||||||
|
|
||||||
if not is_requester_admin and not self.spam_checker.user_may_create_room(
|
if not is_requester_admin and not await self.spam_checker.user_may_create_room(
|
||||||
user_id
|
user_id
|
||||||
):
|
):
|
||||||
raise SynapseError(403, "You are not permitted to create rooms")
|
raise SynapseError(403, "You are not permitted to create rooms")
|
||||||
|
|
|
@ -408,7 +408,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||||
)
|
)
|
||||||
block_invite = True
|
block_invite = True
|
||||||
|
|
||||||
if not self.spam_checker.user_may_invite(
|
if not await self.spam_checker.user_may_invite(
|
||||||
requester.user.to_string(), target.to_string(), room_id
|
requester.user.to_string(), target.to_string(), room_id
|
||||||
):
|
):
|
||||||
logger.info("Blocking invite due to spam checker")
|
logger.info("Blocking invite due to spam checker")
|
||||||
|
|
|
@ -81,11 +81,11 @@ class UserDirectoryHandler(StateDeltasHandler):
|
||||||
results = await self.store.search_user_dir(user_id, search_term, limit)
|
results = await self.store.search_user_dir(user_id, search_term, limit)
|
||||||
|
|
||||||
# Remove any spammy users from the results.
|
# Remove any spammy users from the results.
|
||||||
results["results"] = [
|
non_spammy_users = []
|
||||||
user
|
for user in results["results"]:
|
||||||
for user in results["results"]
|
if not await self.spam_checker.check_username_for_spam(user):
|
||||||
if not self.spam_checker.check_username_for_spam(user)
|
non_spammy_users.append(user)
|
||||||
]
|
results["results"] = non_spammy_users
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,6 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import inspect
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
@ -25,6 +24,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.logging.context import LoggingContext, PreserveLoggingContext
|
from synapse.logging.context import LoggingContext, PreserveLoggingContext
|
||||||
from synapse.logging.opentracing import noop_context_manager, start_active_span
|
from synapse.logging.opentracing import noop_context_manager, start_active_span
|
||||||
|
from synapse.util.async_helpers import maybe_awaitable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import resource
|
import resource
|
||||||
|
@ -206,12 +206,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
|
||||||
if bg_start_span:
|
if bg_start_span:
|
||||||
ctx = start_active_span(desc, tags={"request_id": context.request})
|
ctx = start_active_span(desc, tags={"request_id": context.request})
|
||||||
with ctx:
|
with ctx:
|
||||||
result = func(*args, **kwargs)
|
return await maybe_awaitable(func(*args, **kwargs))
|
||||||
|
|
||||||
if inspect.isawaitable(result):
|
|
||||||
result = await result
|
|
||||||
|
|
||||||
return result
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Background process '%s' threw an exception", desc,
|
"Background process '%s' threw an exception", desc,
|
||||||
|
|
|
@ -13,7 +13,6 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import inspect
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
@ -21,6 +20,7 @@ from typing import Optional
|
||||||
|
|
||||||
from synapse.config._base import Config
|
from synapse.config._base import Config
|
||||||
from synapse.logging.context import defer_to_thread, run_in_background
|
from synapse.logging.context import defer_to_thread, run_in_background
|
||||||
|
from synapse.util.async_helpers import maybe_awaitable
|
||||||
|
|
||||||
from ._base import FileInfo, Responder
|
from ._base import FileInfo, Responder
|
||||||
from .media_storage import FileResponder
|
from .media_storage import FileResponder
|
||||||
|
@ -91,16 +91,14 @@ class StorageProviderWrapper(StorageProvider):
|
||||||
if self.store_synchronous:
|
if self.store_synchronous:
|
||||||
# store_file is supposed to return an Awaitable, but guard
|
# store_file is supposed to return an Awaitable, but guard
|
||||||
# against improper implementations.
|
# against improper implementations.
|
||||||
result = self.backend.store_file(path, file_info)
|
return await maybe_awaitable(self.backend.store_file(path, file_info))
|
||||||
if inspect.isawaitable(result):
|
|
||||||
return await result
|
|
||||||
else:
|
else:
|
||||||
# TODO: Handle errors.
|
# TODO: Handle errors.
|
||||||
async def store():
|
async def store():
|
||||||
try:
|
try:
|
||||||
result = self.backend.store_file(path, file_info)
|
return await maybe_awaitable(
|
||||||
if inspect.isawaitable(result):
|
self.backend.store_file(path, file_info)
|
||||||
return await result
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error storing file")
|
logger.exception("Error storing file")
|
||||||
|
|
||||||
|
@ -110,9 +108,7 @@ class StorageProviderWrapper(StorageProvider):
|
||||||
async def fetch(self, path, file_info):
|
async def fetch(self, path, file_info):
|
||||||
# store_file is supposed to return an Awaitable, but guard
|
# store_file is supposed to return an Awaitable, but guard
|
||||||
# against improper implementations.
|
# against improper implementations.
|
||||||
result = self.backend.fetch(path, file_info)
|
return await maybe_awaitable(self.backend.fetch(path, file_info))
|
||||||
if inspect.isawaitable(result):
|
|
||||||
return await result
|
|
||||||
|
|
||||||
|
|
||||||
class FileStorageProviderBackend(StorageProvider):
|
class FileStorageProviderBackend(StorageProvider):
|
||||||
|
|
|
@ -618,7 +618,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
return StatsHandler(self)
|
return StatsHandler(self)
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_spam_checker(self):
|
def get_spam_checker(self) -> SpamChecker:
|
||||||
return SpamChecker(self)
|
return SpamChecker(self)
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
|
|
|
@ -15,10 +15,12 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Hashable,
|
Hashable,
|
||||||
|
@ -542,11 +544,11 @@ class DoneAwaitable:
|
||||||
raise StopIteration(self.value)
|
raise StopIteration(self.value)
|
||||||
|
|
||||||
|
|
||||||
def maybe_awaitable(value):
|
def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
|
||||||
"""Convert a value to an awaitable if not already an awaitable.
|
"""Convert a value to an awaitable if not already an awaitable.
|
||||||
"""
|
"""
|
||||||
|
if inspect.isawaitable(value):
|
||||||
if hasattr(value, "__await__"):
|
assert isinstance(value, Awaitable)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
return DoneAwaitable(value)
|
return DoneAwaitable(value)
|
||||||
|
|
|
@ -12,13 +12,13 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import inspect
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
from synapse.util.async_helpers import maybe_awaitable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -105,10 +105,7 @@ class Signal:
|
||||||
|
|
||||||
async def do(observer):
|
async def do(observer):
|
||||||
try:
|
try:
|
||||||
result = observer(*args, **kwargs)
|
return await maybe_awaitable(observer(*args, **kwargs))
|
||||||
if inspect.isawaitable(result):
|
|
||||||
result = await result
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"%s signal observer %s failed: %r", self.name, observer, e,
|
"%s signal observer %s failed: %r", self.name, observer, e,
|
||||||
|
|
|
@ -270,7 +270,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
spam_checker = self.hs.get_spam_checker()
|
spam_checker = self.hs.get_spam_checker()
|
||||||
|
|
||||||
class AllowAll:
|
class AllowAll:
|
||||||
def check_username_for_spam(self, user_profile):
|
async def check_username_for_spam(self, user_profile):
|
||||||
# Allow all users.
|
# Allow all users.
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -283,7 +283,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Configure a spam checker that filters all users.
|
# Configure a spam checker that filters all users.
|
||||||
class BlockAll:
|
class BlockAll:
|
||||||
def check_username_for_spam(self, user_profile):
|
async def check_username_for_spam(self, user_profile):
|
||||||
# All users are spammy.
|
# All users are spammy.
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue