mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-18 08:54:54 +03:00
forward requester id to check username for spam callbacks (#17916)
This commit is contained in:
parent
483602efb2
commit
eedab12e6d
6 changed files with 70 additions and 7 deletions
1
changelog.d/17916.feature
Normal file
1
changelog.d/17916.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Module developers will have access to user id of requester when adding `check_username_for_spam` callbacks to `spam_checker_module_callbacks`. Contributed by Wilson@Pangea.chat.
|
|
@ -245,7 +245,7 @@ this callback.
|
|||
_First introduced in Synapse v1.37.0_
|
||||
|
||||
```python
|
||||
async def check_username_for_spam(user_profile: synapse.module_api.UserProfile) -> bool
|
||||
async def check_username_for_spam(user_profile: synapse.module_api.UserProfile, requester_id: str) -> bool
|
||||
```
|
||||
|
||||
Called when computing search results in the user directory. The module must return a
|
||||
|
@ -264,6 +264,8 @@ The profile is represented as a dictionary with the following keys:
|
|||
The module is given a copy of the original dictionary, so modifying it from within the
|
||||
module cannot modify a user's profile when included in user directory search results.
|
||||
|
||||
The requester_id parameter is the ID of the user that called the user directory API.
|
||||
|
||||
If multiple modules implement this callback, they will be considered in order. If a
|
||||
callback returns `False`, Synapse falls through to the next one. The value of the first
|
||||
callback that does not return `False` will be used. If this happens, Synapse will not call
|
||||
|
|
|
@ -72,8 +72,8 @@ class ExampleSpamChecker:
|
|||
async def user_may_publish_room(self, userid, room_id):
|
||||
return True # allow publishing of all rooms
|
||||
|
||||
async def check_username_for_spam(self, user_profile):
|
||||
return False # allow all usernames
|
||||
async def check_username_for_spam(self, user_profile, requester_id):
|
||||
return False # allow all usernames regardless of requester
|
||||
|
||||
async def check_registration_for_spam(
|
||||
self,
|
||||
|
|
|
@ -161,7 +161,7 @@ class UserDirectoryHandler(StateDeltasHandler):
|
|||
non_spammy_users = []
|
||||
for user in results["results"]:
|
||||
if not await self._spam_checker_module_callbacks.check_username_for_spam(
|
||||
user
|
||||
user, user_id
|
||||
):
|
||||
non_spammy_users.append(user)
|
||||
results["results"] = non_spammy_users
|
||||
|
|
|
@ -31,6 +31,7 @@ from typing import (
|
|||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
# `Literal` appears with Python 3.8.
|
||||
|
@ -168,7 +169,10 @@ USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[
|
|||
]
|
||||
],
|
||||
]
|
||||
CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]]
|
||||
CHECK_USERNAME_FOR_SPAM_CALLBACK = Union[
|
||||
Callable[[UserProfile], Awaitable[bool]],
|
||||
Callable[[UserProfile, str], Awaitable[bool]],
|
||||
]
|
||||
LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
|
||||
[
|
||||
Optional[dict],
|
||||
|
@ -716,7 +720,9 @@ class SpamCheckerModuleApiCallbacks:
|
|||
|
||||
return self.NOT_SPAM
|
||||
|
||||
async def check_username_for_spam(self, user_profile: UserProfile) -> bool:
|
||||
async def check_username_for_spam(
|
||||
self, user_profile: UserProfile, requester_id: str
|
||||
) -> bool:
|
||||
"""Checks if a user ID or display name are considered "spammy" by this server.
|
||||
|
||||
If the server considers a username spammy, then it will not be included in
|
||||
|
@ -727,15 +733,33 @@ class SpamCheckerModuleApiCallbacks:
|
|||
* user_id
|
||||
* display_name
|
||||
* avatar_url
|
||||
requester_id: The user ID of the user making the user directory search request.
|
||||
|
||||
Returns:
|
||||
True if the user is spammy.
|
||||
"""
|
||||
for callback in self._check_username_for_spam_callbacks:
|
||||
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
|
||||
checker_args = inspect.signature(callback)
|
||||
# Make a copy of the user profile object to ensure the spam checker cannot
|
||||
# modify it.
|
||||
res = await delay_cancellation(callback(user_profile.copy()))
|
||||
# Also ensure backwards compatibility with spam checker callbacks
|
||||
# that don't expect the requester_id argument.
|
||||
if len(checker_args.parameters) == 2:
|
||||
callback_with_requester_id = cast(
|
||||
Callable[[UserProfile, str], Awaitable[bool]], callback
|
||||
)
|
||||
res = await delay_cancellation(
|
||||
callback_with_requester_id(user_profile.copy(), requester_id)
|
||||
)
|
||||
else:
|
||||
callback_without_requester_id = cast(
|
||||
Callable[[UserProfile], Awaitable[bool]], callback
|
||||
)
|
||||
res = await delay_cancellation(
|
||||
callback_without_requester_id(user_profile.copy())
|
||||
)
|
||||
|
||||
if res:
|
||||
return True
|
||||
|
||||
|
|
|
@ -796,6 +796,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
|||
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|
||||
self.assertEqual(len(s["results"]), 1)
|
||||
|
||||
# Kept old spam checker without `requester_id` tests for backwards compatibility.
|
||||
async def allow_all(user_profile: UserProfile) -> bool:
|
||||
# Allow all users.
|
||||
return False
|
||||
|
@ -809,6 +810,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
|||
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|
||||
self.assertEqual(len(s["results"]), 1)
|
||||
|
||||
# Kept old spam checker without `requester_id` tests for backwards compatibility.
|
||||
# Configure a spam checker that filters all users.
|
||||
async def block_all(user_profile: UserProfile) -> bool:
|
||||
# All users are spammy.
|
||||
|
@ -820,6 +822,40 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
|||
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|
||||
self.assertEqual(len(s["results"]), 0)
|
||||
|
||||
async def allow_all_expects_requester_id(
|
||||
user_profile: UserProfile, requester_id: str
|
||||
) -> bool:
|
||||
self.assertEqual(requester_id, u1)
|
||||
# Allow all users.
|
||||
return False
|
||||
|
||||
# Configure a spam checker that does not filter any users.
|
||||
spam_checker = self.hs.get_module_api_callbacks().spam_checker
|
||||
spam_checker._check_username_for_spam_callbacks = [
|
||||
allow_all_expects_requester_id
|
||||
]
|
||||
|
||||
# The results do not change:
|
||||
# We get one search result when searching for user2 by user1.
|
||||
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|
||||
self.assertEqual(len(s["results"]), 1)
|
||||
|
||||
# Configure a spam checker that filters all users.
|
||||
async def block_all_expects_requester_id(
|
||||
user_profile: UserProfile, requester_id: str
|
||||
) -> bool:
|
||||
self.assertEqual(requester_id, u1)
|
||||
# All users are spammy.
|
||||
return True
|
||||
|
||||
spam_checker._check_username_for_spam_callbacks = [
|
||||
block_all_expects_requester_id
|
||||
]
|
||||
|
||||
# User1 now gets no search results for any of the other users.
|
||||
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|
||||
self.assertEqual(len(s["results"]), 0)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"spam_checker": {
|
||||
|
|
Loading…
Reference in a new issue